diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index 848bfaa7..6af597e1 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use crate::{ - utils, Error, Result, Ruma, services, + utils, Error, Result, Ruma, services, api::client_server, }; use ruma::{ api::client::{ @@ -381,7 +381,7 @@ pub async fn deactivate_route( } // Make the user leave all rooms before deactivation - services().rooms.leave_all_rooms(&sender_user).await?; + client_server::leave_all_rooms(&sender_user).await?; // Remove devices and mark account as deactivated services().users.deactivate_account(sender_user)?; diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 7aa5fb2c..444cc15f 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -25,12 +25,12 @@ pub async fn create_alias_route( )); } - if services().rooms.id_from_alias(&body.room_alias)?.is_some() { + if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() { return Err(Error::Conflict("Alias already exists.")); } - services().rooms - .set_alias(&body.room_alias, Some(&body.room_id))?; + services().rooms.alias + .set_alias(&body.room_alias, &body.room_id)?; Ok(create_alias::v3::Response::new()) } @@ -51,7 +51,7 @@ pub async fn delete_alias_route( )); } - services().rooms.set_alias(&body.room_alias, None)?; + services().rooms.alias.remove_alias(&body.room_alias)?; // TODO: update alt_aliases? @@ -88,7 +88,7 @@ pub(crate) async fn get_alias_helper( } let mut room_id = None; - match services().rooms.id_from_alias(room_alias)? { + match services().rooms.alias.resolve_local_alias(room_alias)? { Some(r) => room_id = Some(r), None => { for (_id, registration) in services().appservice.all()? { @@ -115,7 +115,7 @@ pub(crate) async fn get_alias_helper( .await .is_ok() { - room_id = Some(services().rooms.id_from_alias(room_alias)?.ok_or_else(|| { + room_id = Some(services().rooms.alias.resolve_local_alias(room_alias)?.ok_or_else(|| { Error::bad_config("Appservice lied to us. Room does not exist.") })?); break; diff --git a/src/api/client_server/context.rs b/src/api/client_server/context.rs index 3551dcfd..c407c71e 100644 --- a/src/api/client_server/context.rs +++ b/src/api/client_server/context.rs @@ -29,16 +29,18 @@ pub async fn get_context_route( let base_pdu_id = services() .rooms + .timeline .get_pdu_id(&body.event_id)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "Base event id not found.", ))?; - let base_token = services().rooms.pdu_count(&base_pdu_id)?; + let base_token = services().rooms.timeline.pdu_count(&base_pdu_id)?; let base_event = services() .rooms + .timeline .get_pdu_from_id(&base_pdu_id)? .ok_or(Error::BadRequest( ErrorKind::NotFound, @@ -47,14 +49,14 @@ pub async fn get_context_route( let room_id = base_event.room_id.clone(); - if !services().rooms.is_joined(sender_user, &room_id)? { + if !services().rooms.state_cache.is_joined(sender_user, &room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", )); } - if !services().rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -68,6 +70,7 @@ pub async fn get_context_route( let events_before: Vec<_> = services() .rooms + .timeline .pdus_until(sender_user, &room_id, base_token)? .take( u32::try_from(body.limit).map_err(|_| { @@ -79,7 +82,7 @@ pub async fn get_context_route( .collect(); for (_, event) in &events_before { - if !services().rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -92,7 +95,7 @@ pub async fn get_context_route( let start_token = events_before .last() - .and_then(|(pdu_id, _)| services().rooms.pdu_count(pdu_id).ok()) + .and_then(|(pdu_id, _)| services().rooms.timeline.pdu_count(pdu_id).ok()) .map(|count| count.to_string()); let events_before: Vec<_> = events_before @@ -102,6 +105,7 @@ pub async fn get_context_route( let events_after: Vec<_> = services() .rooms + .timeline .pdus_after(sender_user, &room_id, base_token)? .take( u32::try_from(body.limit).map_err(|_| { @@ -113,7 +117,7 @@ pub async fn get_context_route( .collect(); for (_, event) in &events_after { - if !services().rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -124,7 +128,7 @@ pub async fn get_context_route( } } - let shortstatehash = match services().rooms.pdu_shortstatehash( + let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash( events_after .last() .map_or(&*body.event_id, |(_, e)| &*e.event_id), @@ -132,15 +136,16 @@ pub async fn get_context_route( Some(s) => s, None => services() .rooms - .current_shortstatehash(&room_id)? + .state + .get_room_shortstatehash(&room_id)? .expect("All rooms have state"), }; - let state_ids = services().rooms.state_full_ids(shortstatehash).await?; + let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?; let end_token = events_after .last() - .and_then(|(pdu_id, _)| services().rooms.pdu_count(pdu_id).ok()) + .and_then(|(pdu_id, _)| services().rooms.timeline.pdu_count(pdu_id).ok()) .map(|count| count.to_string()); let events_after: Vec<_> = events_after @@ -151,10 +156,10 @@ pub async fn get_context_route( let mut state = Vec::new(); for (shortstatekey, id) in state_ids { - let (event_type, state_key) = services().rooms.get_statekey_from_short(shortstatekey)?; + let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?; if event_type != StateEventType::RoomMember { - let pdu = match services().rooms.get_pdu(&id)? { + let pdu = match services().rooms.timeline.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -163,7 +168,7 @@ pub async fn get_context_route( }; state.push(pdu.to_state_event()); } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let pdu = match services().rooms.get_pdu(&id)? { + let pdu = match services().rooms.timeline.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); diff --git a/src/api/client_server/directory.rs b/src/api/client_server/directory.rs index 87493fa0..2a60f672 100644 --- a/src/api/client_server/directory.rs +++ b/src/api/client_server/directory.rs @@ -86,10 +86,10 @@ pub async fn set_room_visibility_route( match &body.visibility { room::Visibility::Public => { - services().rooms.set_public(&body.room_id, true)?; + services().rooms.directory.set_public(&body.room_id)?; info!("{} made {} public", sender_user, body.room_id); } - room::Visibility::Private => services().rooms.set_public(&body.room_id, false)?, + room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?, _ => { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -108,7 +108,7 @@ pub async fn get_room_visibility_route( body: Ruma, ) -> Result { Ok(get_room_visibility::v3::Response { - visibility: if services().rooms.is_public_room(&body.room_id)? { + visibility: if services().rooms.directory.is_public_room(&body.room_id)? { room::Visibility::Public } else { room::Visibility::Private @@ -176,6 +176,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( let mut all_rooms: Vec<_> = services() .rooms + .directory .public_rooms() .map(|room_id| { let room_id = room_id?; @@ -183,6 +184,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( let chunk = PublicRoomsChunk { canonical_alias: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? .map_or(Ok(None), |s| { serde_json::from_str(s.content.get()) @@ -193,6 +195,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( })?, name: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomName, "")? .map_or(Ok(None), |s| { serde_json::from_str(s.content.get()) @@ -203,6 +206,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( })?, num_joined_members: services() .rooms + .state_cache .room_joined_count(&room_id)? .unwrap_or_else(|| { warn!("Room {} has no member count", room_id); @@ -212,6 +216,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( .expect("user count should not be that big"), topic: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomTopic, "")? .map_or(Ok(None), |s| { serde_json::from_str(s.content.get()) @@ -222,6 +227,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( })?, world_readable: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? .map_or(Ok(false), |s| { serde_json::from_str(s.content.get()) @@ -236,6 +242,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( })?, guest_can_join: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? .map_or(Ok(false), |s| { serde_json::from_str(s.content.get()) @@ -248,6 +255,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( })?, avatar_url: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomAvatar, "")? .map(|s| { serde_json::from_str(s.content.get()) @@ -261,6 +269,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( .flatten(), join_rule: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? .map(|s| { serde_json::from_str(s.content.get()) diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index 698bd1ec..4ce5d4c0 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -230,7 +230,7 @@ pub async fn get_key_changes_route( .filter_map(|r| r.ok()), ); - for room_id in services().rooms.rooms_joined(sender_user).filter_map(|r| r.ok()) { + for room_id in services().rooms.state_cache.rooms_joined(sender_user).filter_map(|r| r.ok()) { device_list_updates.extend( services().users .keys_changed( diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index f0da0849..d6e8213c 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -99,7 +99,7 @@ pub async fn get_content_route( content_disposition, content_type, file, - }) = services().media.get(&mxc).await? + }) = services().media.get(mxc.clone()).await? { Ok(get_content::v3::Response { file, @@ -129,7 +129,7 @@ pub async fn get_content_as_filename_route( content_disposition: _, content_type, file, - }) = services().media.get(&mxc).await? + }) = services().media.get(mxc.clone()).await? { Ok(get_content_as_filename::v3::Response { file, @@ -165,7 +165,7 @@ pub async fn get_content_thumbnail_route( }) = services() .media .get_thumbnail( - &mxc, + mxc.clone(), body.width .try_into() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index b000ec1b..d6f820a7 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -30,7 +30,7 @@ use std::{ }; use tracing::{debug, error, warn}; -use crate::{services, PduEvent, service::pdu::{gen_event_id_canonical_json, PduBuilder}, Error, api::{server_server}, utils, Ruma}; +use crate::{Result, services, PduEvent, service::pdu::{gen_event_id_canonical_json, PduBuilder}, Error, api::{server_server, client_server}, utils, Ruma}; use super::get_alias_helper; @@ -48,6 +48,7 @@ pub async fn join_room_by_id_route( let mut servers = Vec::new(); // There is no body.server_name for /roomId/join servers.extend( services().rooms + .state_cache .invite_state(sender_user, &body.room_id)? .unwrap_or_default() .iter() @@ -88,6 +89,7 @@ pub async fn join_room_by_id_or_alias_route( let mut servers = body.server_name.clone(); servers.extend( services().rooms + .state_cache .invite_state(sender_user, &room_id)? .unwrap_or_default() .iter() @@ -131,7 +133,7 @@ pub async fn leave_room_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services().rooms.leave_room(sender_user, &body.room_id).await?; + leave_room(sender_user, &body.room_id).await?; Ok(leave_room::v3::Response::new()) } @@ -162,6 +164,7 @@ pub async fn kick_user_route( let mut event: RoomMemberEventContent = serde_json::from_str( services().rooms + .state_accessor .room_state_get( &body.room_id, &StateEventType::RoomMember, @@ -189,7 +192,7 @@ pub async fn kick_user_route( ); let state_lock = mutex_state.lock().await; - services().rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -219,6 +222,7 @@ pub async fn ban_user_route( let event = services() .rooms + .state_accessor .room_state_get( &body.room_id, &StateEventType::RoomMember, @@ -255,7 +259,7 @@ pub async fn ban_user_route( ); let state_lock = mutex_state.lock().await; - services().rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -283,6 +287,7 @@ pub async fn unban_user_route( let mut event: RoomMemberEventContent = serde_json::from_str( services().rooms + .state_accessor .room_state_get( &body.room_id, &StateEventType::RoomMember, @@ -309,7 +314,7 @@ pub async fn unban_user_route( ); let state_lock = mutex_state.lock().await; - services().rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -340,7 +345,7 @@ pub async fn forget_room_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services().rooms.forget(&body.room_id, sender_user)?; + services().rooms.state_cache.forget(&body.room_id, sender_user)?; Ok(forget_room::v3::Response::new()) } @@ -356,6 +361,7 @@ pub async fn joined_rooms_route( Ok(joined_rooms::v3::Response { joined_rooms: services() .rooms + .state_cache .rooms_joined(sender_user) .filter_map(|r| r.ok()) .collect(), @@ -373,7 +379,7 @@ pub async fn get_member_events_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // TODO: check history visibility? - if !services().rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -383,6 +389,7 @@ pub async fn get_member_events_route( Ok(get_member_events::v3::Response { chunk: services() .rooms + .state_accessor .room_state_full(&body.room_id) .await? .iter() @@ -403,7 +410,7 @@ pub async fn joined_members_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You aren't a member of the room.", @@ -411,7 +418,7 @@ pub async fn joined_members_route( } let mut joined = BTreeMap::new(); - for user_id in services().rooms.room_members(&body.room_id).filter_map(|r| r.ok()) { + for user_id in services().rooms.state_cache.room_members(&body.room_id).filter_map(|r| r.ok()) { let display_name = services().users.displayname(&user_id)?; let avatar_url = services().users.avatar_url(&user_id)?; @@ -446,7 +453,7 @@ async fn join_room_by_id_helper( let state_lock = mutex_state.lock().await; // Ask a remote server if we don't have this room - if !services().rooms.exists(room_id)? { + if !services().rooms.metadata.exists(room_id)? { let mut make_join_response_and_server = Err(Error::BadServerResponse( "No server available to assist in joining.", )); @@ -553,7 +560,7 @@ async fn join_room_by_id_helper( ) .await?; - services().rooms.get_or_create_shortroomid(room_id, &services().globals)?; + services().rooms.short.get_or_create_shortroomid(room_id)?; let parsed_pdu = PduEvent::from_id_val(event_id, join_event.clone()) .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; @@ -586,7 +593,7 @@ async fn join_room_by_id_helper( services().rooms.add_pdu_outlier(&event_id, &value)?; if let Some(state_key) = &pdu.state_key { - let shortstatekey = services().rooms.get_or_create_shortstatekey( + let shortstatekey = services().rooms.short.get_or_create_shortstatekey( &pdu.kind.to_string().into(), state_key, )?; @@ -594,7 +601,7 @@ async fn join_room_by_id_helper( } } - let incoming_shortstatekey = services().rooms.get_or_create_shortstatekey( + let incoming_shortstatekey = services().rooms.short.get_or_create_shortstatekey( &parsed_pdu.kind.to_string().into(), parsed_pdu .state_key @@ -606,6 +613,7 @@ async fn join_room_by_id_helper( let create_shortstatekey = services() .rooms + .short .get_shortstatekey(&StateEventType::RoomCreate, "")? .expect("Room exists"); @@ -613,7 +621,7 @@ async fn join_room_by_id_helper( return Err(Error::BadServerResponse("State contained no create event.")); } - services().rooms.force_state( + services().rooms.state.force_state( room_id, state .into_iter() @@ -780,7 +788,7 @@ pub(crate) async fn invite_helper<'a>( redacts: None, }, sender_user, room_id, &state_lock); - let invite_room_state = services().rooms.calculate_invite_state(&pdu)?; + let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; drop(state_lock); diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index 14affc65..f8d06023 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -87,7 +87,7 @@ pub async fn create_room_route( Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.") })?; - if services().rooms.id_from_alias(&alias)?.is_some() { + if services().rooms.alias.resolve_local_alias(&alias)?.is_some() { Err(Error::BadRequest( ErrorKind::RoomInUse, "Room alias already exists.", diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index 4e8d594e..b2dfe2a7 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -246,7 +246,7 @@ async fn send_state_event_for_key_helper( if alias.server_name() != services().globals.server_name() || services() .rooms - .id_from_alias(&alias)? + .alias.resolve_local_alias(&alias)? .filter(|room| room == room_id) // Make sure it's the right room .is_none() { diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 776777d1..bacc1ac7 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1842,7 +1842,7 @@ pub async fn get_room_information_route( let room_id = services() .rooms - .id_from_alias(&body.room_alias)? + .alias.resolve_local_alias(&body.room_alias)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "Room alias not found.", diff --git a/src/service/account_data.rs b/src/database/key_value/account_data.rs similarity index 90% rename from src/service/account_data.rs rename to src/database/key_value/account_data.rs index 70ad9f2a..49c9170f 100644 --- a/src/service/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -1,17 +1,14 @@ -use crate::{utils, Error, Result}; -use ruma::{ - api::client::error::ErrorKind, - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, -}; -use serde::{de::DeserializeOwned, Serialize}; -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; -impl AccountData { +use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}, events::{RoomAccountDataEventType, AnyEphemeralRoomEvent}, serde::Raw, RoomId}; +use serde::{Serialize, de::DeserializeOwned}; + +use crate::{Result, database::KeyValueDatabase, service, Error, utils, services}; + +impl service::account_data::Data for KeyValueDatabase { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - pub fn update( + fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, @@ -63,7 +60,7 @@ impl AccountData { /// Searches the account data for a specific kind. #[tracing::instrument(skip(self, room_id, user_id, kind))] - pub fn get( + fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, @@ -96,7 +93,7 @@ impl AccountData { /// Returns all changes to the account data that happened after `since`. #[tracing::instrument(skip(self, room_id, user_id, since))] - pub fn changes_since( + fn changes_since( &self, room_id: Option<&RoomId>, user_id: &UserId, diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index eae2cfbc..edb027e9 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,4 +1,4 @@ -use crate::{database::KeyValueDatabase, service, utils, Error}; +use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::appservice::Data for KeyValueDatabase { /// Registers an appservice and returns the ID to the caller @@ -54,7 +54,7 @@ impl service::appservice::Data for KeyValueDatabase { ) } - fn iter_ids(&self) -> Result> + '_> { + fn iter_ids(&self) -> Result>>> { Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { utils::string_from_bytes(&id) .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs new file mode 100644 index 00000000..81e6ee1f --- /dev/null +++ b/src/database/key_value/globals.rs @@ -0,0 +1,40 @@ +use ruma::signatures::Ed25519KeyPair; + +use crate::{Result, service, database::KeyValueDatabase, Error, utils}; + +impl service::globals::Data for KeyValueDatabase { + fn load_keypair(&self) -> Result { + let keypair_bytes = self.globals.get(b"keypair")?.map_or_else( + || { + let keypair = utils::generate_keypair(); + self.globals.insert(b"keypair", &keypair)?; + Ok::<_, Error>(keypair) + }, + |s| Ok(s.to_vec()), + )?; + + let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff); + + let keypair = utils::string_from_bytes( + // 1. version + parts + .next() + .expect("splitn always returns at least one element"), + ) + .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) + .and_then(|version| { + // 2. key + parts + .next() + .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) + .map(|key| (version, key)) + }) + .and_then(|(version, key)| { + Ed25519KeyPair::from_der(key, version) + .map_err(|_| Error::bad_database("Private or public keys are invalid.")) + }); + } + fn remove_keypair(&self) -> Result<()> { + self.globals.remove(b"keypair")? + } +} diff --git a/src/service/key_backups.rs b/src/database/key_value/key_backups.rs similarity index 92% rename from src/service/key_backups.rs rename to src/database/key_value/key_backups.rs index be1d6b18..8171451c 100644 --- a/src/service/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -1,16 +1,11 @@ -use crate::{utils, Error, Result, services}; -use ruma::{ - api::client::{ - backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, - error::ErrorKind, - }, - serde::Raw, - RoomId, UserId, -}; -use std::{collections::BTreeMap, sync::Arc}; +use std::collections::BTreeMap; -impl KeyBackups { - pub fn create_backup( +use ruma::{UserId, serde::Raw, api::client::{backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, error::ErrorKind}, RoomId}; + +use crate::{Result, service, database::KeyValueDatabase, services, Error, utils}; + +impl service::key_backups::Data for KeyValueDatabase { + fn create_backup( &self, user_id: &UserId, backup_metadata: &Raw, @@ -30,7 +25,7 @@ impl KeyBackups { Ok(version) } - pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { + fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(version.as_bytes()); @@ -47,7 +42,7 @@ impl KeyBackups { Ok(()) } - pub fn update_backup( + fn update_backup( &self, user_id: &UserId, version: &str, @@ -71,7 +66,7 @@ impl KeyBackups { Ok(version.to_owned()) } - pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { + fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); let mut last_possible_key = prefix.clone(); @@ -92,7 +87,7 @@ impl KeyBackups { .transpose() } - pub fn get_latest_backup( + fn get_latest_backup( &self, user_id: &UserId, ) -> Result)>> { @@ -123,7 +118,7 @@ impl KeyBackups { .transpose() } - pub fn get_backup( + fn get_backup( &self, user_id: &UserId, version: &str, @@ -140,7 +135,7 @@ impl KeyBackups { }) } - pub fn add_key( + fn add_key( &self, user_id: &UserId, version: &str, @@ -173,7 +168,7 @@ impl KeyBackups { Ok(()) } - pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { + fn count_keys(&self, user_id: &UserId, version: &str) -> Result { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); prefix.extend_from_slice(version.as_bytes()); @@ -181,7 +176,7 @@ impl KeyBackups { Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) } - pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result { + fn get_etag(&self, user_id: &UserId, version: &str) -> Result { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(version.as_bytes()); @@ -196,7 +191,7 @@ impl KeyBackups { .to_string()) } - pub fn get_all( + fn get_all( &self, user_id: &UserId, version: &str, @@ -252,7 +247,7 @@ impl KeyBackups { Ok(rooms) } - pub fn get_room( + fn get_room( &self, user_id: &UserId, version: &str, @@ -289,7 +284,7 @@ impl KeyBackups { .collect()) } - pub fn get_session( + fn get_session( &self, user_id: &UserId, version: &str, @@ -314,7 +309,7 @@ impl KeyBackups { .transpose() } - pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { + fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(version.as_bytes()); @@ -327,7 +322,7 @@ impl KeyBackups { Ok(()) } - pub fn delete_room_keys( + fn delete_room_keys( &self, user_id: &UserId, version: &str, @@ -347,7 +342,7 @@ impl KeyBackups { Ok(()) } - pub fn delete_room_key( + fn delete_room_key( &self, user_id: &UserId, version: &str, diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs new file mode 100644 index 00000000..90a5c590 --- /dev/null +++ b/src/database/key_value/media.rs @@ -0,0 +1,66 @@ +use crate::{database::KeyValueDatabase, service, Error, utils, Result}; + +impl service::media::Data for KeyValueDatabase { + fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: &Option<&str>, content_type: &Option<&str>) -> Result> { + let mut key = mxc.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&width.to_be_bytes()); + key.extend_from_slice(&height.to_be_bytes()); + key.push(0xff); + key.extend_from_slice( + content_disposition + .as_ref() + .map(|f| f.as_bytes()) + .unwrap_or_default(), + ); + key.push(0xff); + key.extend_from_slice( + content_type + .as_ref() + .map(|c| c.as_bytes()) + .unwrap_or_default(), + ); + + self.mediaid_file.insert(&key, &[])?; + + Ok(key) + } + + fn search_file_metadata(&self, mxc: String, width: u32, height: u32) -> Result<(Option, Option, Vec)> { + let mut prefix = mxc.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail + prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail + prefix.push(0xff); + + let (key, _) = self.mediaid_file.scan_prefix(prefix).next().ok_or(Error::NotFound)?; + + let mut parts = key.rsplit(|&b| b == 0xff); + + let content_type = parts + .next() + .map(|bytes| { + utils::string_from_bytes(bytes).map_err(|_| { + Error::bad_database("Content type in mediaid_file is invalid unicode.") + }) + }) + .transpose()?; + + let content_disposition_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; + + let content_disposition = if content_disposition_bytes.is_empty() { + None + } else { + Some( + utils::string_from_bytes(content_disposition_bytes).map_err(|_| { + Error::bad_database( + "Content Disposition in mediaid_file is invalid unicode.", + ) + })?, + ) + }; + Ok((content_disposition, content_type, key)) + } +} diff --git a/src/database/key_value/mod.rs b/src/database/key_value/mod.rs index 189571f6..efb85509 100644 --- a/src/database/key_value/mod.rs +++ b/src/database/key_value/mod.rs @@ -1,9 +1,9 @@ -//mod account_data; +mod account_data; //mod admin; mod appservice; -//mod globals; -//mod key_backups; -//mod media; +mod globals; +mod key_backups; +mod media; //mod pdu; mod pusher; mod rooms; diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index b77170db..35c84638 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -1,6 +1,6 @@ use ruma::{UserId, api::client::push::{set_pusher, get_pushers}}; -use crate::{service, database::KeyValueDatabase, Error}; +use crate::{service, database::KeyValueDatabase, Error, Result}; impl service::pusher::Data for KeyValueDatabase { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { @@ -51,7 +51,7 @@ impl service::pusher::Data for KeyValueDatabase { fn get_pusher_senderkeys<'a>( &'a self, sender: &UserId, - ) -> impl Iterator> + 'a { + ) -> Box>> { let mut prefix = sender.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index a9236a75..c762defa 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -1,12 +1,12 @@ use ruma::{RoomId, RoomAliasId, api::client::error::ErrorKind}; -use crate::{service, database::KeyValueDatabase, utils, Error, services}; +use crate::{service, database::KeyValueDatabase, utils, Error, services, Result}; impl service::rooms::alias::Data for KeyValueDatabase { fn set_alias( &self, alias: &RoomAliasId, - room_id: Option<&RoomId> + room_id: &RoomId ) -> Result<()> { self.alias_roomid .insert(alias.alias().as_bytes(), room_id.as_bytes())?; @@ -41,7 +41,7 @@ impl service::rooms::alias::Data for KeyValueDatabase { fn resolve_local_alias( &self, alias: &RoomAliasId - ) -> Result<()> { + ) -> Result>> { self.alias_roomid .get(alias.alias().as_bytes())? .map(|bytes| { @@ -56,7 +56,7 @@ impl service::rooms::alias::Data for KeyValueDatabase { fn local_aliases_for_room( &self, room_id: &RoomId, - ) -> Result<()> { + ) -> Result>> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs index 57dbb147..585d5626 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/database/key_value/rooms/auth_chain.rs @@ -1,5 +1,9 @@ -impl service::room::auth_chain::Data for KeyValueDatabase { - fn get_cached_eventid_authchain<'a>() -> Result> { +use std::{collections::HashSet, mem::size_of}; + +use crate::{service, database::KeyValueDatabase, Result, utils}; + +impl service::rooms::auth_chain::Data for KeyValueDatabase { + fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result> { self.shorteventid_authchain .get(&shorteventid.to_be_bytes())? .map(|chain| { @@ -12,8 +16,8 @@ impl service::room::auth_chain::Data for KeyValueDatabase { }) } - fn cache_eventid_authchain<'a>(shorteventid: u64, auth_chain: &HashSet) -> Result<()> { - shorteventid_authchain.insert( + fn cache_eventid_authchain(&self, shorteventid: u64, auth_chain: &HashSet) -> Result<()> { + self.shorteventid_authchain.insert( &shorteventid.to_be_bytes(), &auth_chain .iter() diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs index 44a580c3..c48afa9a 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/database/key_value/rooms/directory.rs @@ -1,6 +1,6 @@ use ruma::RoomId; -use crate::{service, database::KeyValueDatabase, utils, Error}; +use crate::{service, database::KeyValueDatabase, utils, Error, Result}; impl service::rooms::directory::Data for KeyValueDatabase { fn set_public(&self, room_id: &RoomId) -> Result<()> { @@ -15,7 +15,7 @@ impl service::rooms::directory::Data for KeyValueDatabase { Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) } - fn public_rooms(&self) -> impl Iterator>> + '_ { + fn public_rooms(&self) -> Box>>> { self.publicroomids.iter().map(|(bytes, _)| { RoomId::parse( utils::string_from_bytes(&bytes).map_err(|_| { diff --git a/src/database/key_value/rooms/edus/mod.rs b/src/database/key_value/rooms/edus/mod.rs index 9ffd33da..b5007f89 100644 --- a/src/database/key_value/rooms/edus/mod.rs +++ b/src/database/key_value/rooms/edus/mod.rs @@ -1,3 +1,7 @@ mod presence; mod typing; mod read_receipt; + +use crate::{service, database::KeyValueDatabase}; + +impl service::rooms::edus::Data for KeyValueDatabase {} diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index 9f3977db..fbbbff55 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use ruma::{UserId, RoomId, events::presence::PresenceEvent, presence::PresenceState, UInt}; -use crate::{service, database::KeyValueDatabase, utils, Error, services}; +use crate::{service, database::KeyValueDatabase, utils, Error, services, Result}; impl service::rooms::edus::presence::Data for KeyValueDatabase { fn update_presence( @@ -56,8 +56,8 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase { fn get_presence_event( &self, - user_id: &UserId, room_id: &RoomId, + user_id: &UserId, count: u64, ) -> Result> { let mut presence_id = room_id.as_bytes().to_vec(); diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index 68aea165..42d250f7 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -2,7 +2,7 @@ use std::mem; use ruma::{UserId, RoomId, events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject}; -use crate::{database::KeyValueDatabase, service, utils, Error, services}; +use crate::{database::KeyValueDatabase, service, utils, Error, services, Result}; impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { fn readreceipt_update( @@ -50,13 +50,13 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { &'a self, room_id: &RoomId, since: u64, - ) -> impl Iterator< + ) -> Box, u64, Raw, )>, - > + 'a { + >> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); let prefix2 = prefix.clone(); diff --git a/src/database/key_value/rooms/edus/typing.rs b/src/database/key_value/rooms/edus/typing.rs index 905bffc8..b7d35968 100644 --- a/src/database/key_value/rooms/edus/typing.rs +++ b/src/database/key_value/rooms/edus/typing.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use ruma::{UserId, RoomId}; -use crate::{database::KeyValueDatabase, service, utils, Error, services}; +use crate::{database::KeyValueDatabase, service, utils, Error, services, Result}; impl service::rooms::edus::typing::Data for KeyValueDatabase { fn typing_add( @@ -79,7 +79,7 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase { fn typings_all( &self, room_id: &RoomId, - ) -> Result> { + ) -> Result>> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs index c230cbf7..aaf14dd3 100644 --- a/src/database/key_value/rooms/lazy_load.rs +++ b/src/database/key_value/rooms/lazy_load.rs @@ -1,6 +1,6 @@ use ruma::{UserId, DeviceId, RoomId}; -use crate::{service, database::KeyValueDatabase}; +use crate::{service, database::KeyValueDatabase, Result}; impl service::rooms::lazy_loading::Data for KeyValueDatabase { fn lazy_load_was_sent_before( diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index b4cba2c6..0509cbb8 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -1,6 +1,6 @@ use ruma::RoomId; -use crate::{service, database::KeyValueDatabase}; +use crate::{service, database::KeyValueDatabase, Result}; impl service::rooms::metadata::Data for KeyValueDatabase { fn exists(&self, room_id: &RoomId) -> Result { diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs index adb810ba..406943ed 100644 --- a/src/database/key_value/rooms/mod.rs +++ b/src/database/key_value/rooms/mod.rs @@ -1,16 +1,20 @@ mod alias; +mod auth_chain; mod directory; mod edus; -//mod event_handler; mod lazy_load; mod metadata; mod outlier; mod pdu_metadata; mod search; -//mod short; +mod short; mod state; mod state_accessor; mod state_cache; mod state_compressor; mod timeline; mod user; + +use crate::{database::KeyValueDatabase, service}; + +impl service::rooms::Data for KeyValueDatabase {} diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs index 08299a0c..aa975449 100644 --- a/src/database/key_value/rooms/outlier.rs +++ b/src/database/key_value/rooms/outlier.rs @@ -1,6 +1,6 @@ use ruma::{EventId, signatures::CanonicalJsonObject}; -use crate::{service, database::KeyValueDatabase, PduEvent, Error}; +use crate::{service, database::KeyValueDatabase, PduEvent, Error, Result}; impl service::rooms::outlier::Data for KeyValueDatabase { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index 602f3f6c..f3ac414f 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use ruma::{RoomId, EventId}; -use crate::{service, database::KeyValueDatabase}; +use crate::{service, database::KeyValueDatabase, Result}; impl service::rooms::pdu_metadata::Data for KeyValueDatabase { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index 44663ff3..15937f6d 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -2,10 +2,10 @@ use std::mem::size_of; use ruma::RoomId; -use crate::{service, database::KeyValueDatabase, utils}; +use crate::{service, database::KeyValueDatabase, utils, Result}; impl service::rooms::search::Data for KeyValueDatabase { - fn index_pdu<'a>(&self, room_id: &RoomId, pdu_id: u64, message_body: String) -> Result<()> { + fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: u64, message_body: String) -> Result<()> { let mut batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) @@ -26,7 +26,7 @@ impl service::rooms::search::Data for KeyValueDatabase { &'a self, room_id: &RoomId, search_string: &str, - ) -> Result> + 'a, Vec)>> { + ) -> Result>>, Vec)>> { let prefix = self .get_shortroomid(room_id)? .expect("room exists") diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs new file mode 100644 index 00000000..91296385 --- /dev/null +++ b/src/database/key_value/rooms/short.rs @@ -0,0 +1,4 @@ +use crate::{database::KeyValueDatabase, service}; + +impl service::rooms::short::Data for KeyValueDatabase { +} diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs index 192dbb83..405939dd 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/database/key_value/rooms/state.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::{sync::MutexGuard, collections::HashSet}; use std::fmt::Debug; -use crate::{service, database::KeyValueDatabase, utils, Error}; +use crate::{service, database::KeyValueDatabase, utils, Error, Result}; impl service::rooms::state::Data for KeyValueDatabase { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { @@ -24,7 +24,7 @@ impl service::rooms::state::Data for KeyValueDatabase { Ok(()) } - fn set_event_state(&self, shorteventid: Vec, shortstatehash: Vec) -> Result<()> { + fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { self.shorteventid_shortstatehash .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index ea15afc0..037b98fc 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -1,6 +1,6 @@ use std::{collections::{BTreeMap, HashMap}, sync::Arc}; -use crate::{database::KeyValueDatabase, service, PduEvent, Error, utils}; +use crate::{database::KeyValueDatabase, service, PduEvent, Error, utils, Result}; use async_trait::async_trait; use ruma::{EventId, events::StateEventType, RoomId}; diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 567dc809..5f054858 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,6 +1,6 @@ -use ruma::{UserId, RoomId}; +use ruma::{UserId, RoomId, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw}; -use crate::{service, database::KeyValueDatabase}; +use crate::{service, database::KeyValueDatabase, services, Result}; impl service::rooms::state_cache::Data for KeyValueDatabase { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { @@ -9,4 +9,70 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { userroom_id.extend_from_slice(room_id.as_bytes()); self.roomuseroncejoinedids.insert(&userroom_id, &[]) } + + fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_joined.insert(&userroom_id, &[])?; + self.roomuserid_joined.insert(&roomuser_id, &[])?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } + + fn mark_as_invited(&self, user_id: &UserId, room_id: &RoomId, last_state: Option>>) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_invitestate.insert( + &userroom_id, + &serde_json::to_vec(&last_state.unwrap_or_default()) + .expect("state to bytes always works"), + )?; + self.roomuserid_invitecount + .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } + + fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_leftstate.insert( + &userroom_id, + &serde_json::to_vec(&Vec::>::new()).unwrap(), + )?; // TODO + self.roomuserid_leftcount + .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; + + Ok(()) + } } diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs index 09e35660..23a7122b 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/database/key_value/rooms/state_compressor.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, mem::size_of}; -use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils}; +use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils, Result}; impl service::rooms::state_compressor::Data for KeyValueDatabase { fn get_statediff(&self, shortstatehash: u64) -> Result { diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index cf93df12..c42509e0 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -3,7 +3,7 @@ use std::{collections::hash_map, mem::size_of, sync::Arc}; use ruma::{UserId, RoomId, api::client::error::ErrorKind, EventId, signatures::CanonicalJsonObject}; use tracing::error; -use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent}; +use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result}; impl service::rooms::timeline::Data for KeyValueDatabase { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { @@ -190,7 +190,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { user_id: &UserId, room_id: &RoomId, since: u64, - ) -> Result, PduEvent)>> + 'a> { + ) -> Result, PduEvent)>>>> { let prefix = self .get_shortroomid(room_id)? .expect("room exists") @@ -224,7 +224,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { user_id: &UserId, room_id: &RoomId, until: u64, - ) -> Result, PduEvent)>> + 'a> { + ) -> Result, PduEvent)>>>> { // Create the first part of the full pdu id let prefix = self .get_shortroomid(room_id)? @@ -258,7 +258,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { user_id: &UserId, room_id: &RoomId, from: u64, - ) -> Result, PduEvent)>> + 'a> { + ) -> Result, PduEvent)>>>> { // Create the first part of the full pdu id let prefix = self .get_shortroomid(room_id)? diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 2fc3b9f4..d49bc1d7 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,6 +1,6 @@ use ruma::{UserId, RoomId}; -use crate::{service, database::KeyValueDatabase, utils, Error}; +use crate::{service, database::KeyValueDatabase, utils, Error, Result}; impl service::rooms::user::Data for KeyValueDatabase { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { @@ -78,7 +78,7 @@ impl service::rooms::user::Data for KeyValueDatabase { fn get_shared_rooms<'a>( &'a self, users: Vec>, - ) -> Result>> + 'a> { + ) -> Result>>>> { let iterators = users.into_iter().map(move |user_id| { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/transaction_ids.rs b/src/database/key_value/transaction_ids.rs index 6652a627..a63b3c5d 100644 --- a/src/database/key_value/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -1,6 +1,6 @@ use ruma::{UserId, DeviceId, TransactionId}; -use crate::{service, database::KeyValueDatabase}; +use crate::{service, database::KeyValueDatabase, Result}; impl service::transaction_ids::Data for KeyValueDatabase { fn add_txnid( diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs index b1960bd5..cf242dec 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/database/key_value/uiaa.rs @@ -1,8 +1,6 @@ -use std::io::ErrorKind; +use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}}; -use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::uiaa::UiaaInfo}; - -use crate::{database::KeyValueDatabase, service, Error}; +use crate::{database::KeyValueDatabase, service, Error, Result}; impl service::uiaa::Data for KeyValueDatabase { fn set_uiaa_request( diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index ea844903..82e3bac6 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -3,7 +3,7 @@ use std::{mem::size_of, collections::BTreeMap}; use ruma::{api::client::{filter::IncomingFilterDefinition, error::ErrorKind, device::Device}, UserId, RoomAliasId, MxcUri, DeviceId, MilliSecondsSinceUnixEpoch, DeviceKeyId, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, serde::Raw, events::{AnyToDeviceEvent, StateEventType}, DeviceKeyAlgorithm, UInt}; use tracing::warn; -use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services}; +use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services, Result}; impl service::users::Data for KeyValueDatabase { /// Check if a user has an account on this homeserver. @@ -56,7 +56,7 @@ impl service::users::Data for KeyValueDatabase { } /// Returns an iterator over all users on this homeserver. - fn iter(&self) -> impl Iterator>> + '_ { + fn iter(&self) -> Box>>> { self.userid_password.iter().map(|(bytes, _)| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("User ID in userid_password is invalid unicode.") @@ -270,7 +270,7 @@ impl service::users::Data for KeyValueDatabase { fn all_device_ids<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator>> + 'a { + ) -> Box>>> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); // All devices have metadata @@ -608,7 +608,7 @@ impl service::users::Data for KeyValueDatabase { user_or_room_id: &str, from: u64, to: Option, - ) -> impl Iterator>> + 'a { + ) -> Box>>> { let mut prefix = user_or_room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -878,7 +878,7 @@ impl service::users::Data for KeyValueDatabase { fn all_devices_metadata<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator> + 'a { + ) -> Box>> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); diff --git a/src/database/mod.rs b/src/database/mod.rs index 12758af2..4ea619a8 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,7 +1,7 @@ pub mod abstraction; pub mod key_value; -use crate::{utils, Config, Error, Result, service::{users, globals, uiaa, rooms, account_data, media, key_backups, transaction_ids, sending, admin::{self, create_admin_room}, appservice, pusher}}; +use crate::{utils, Config, Error, Result, service::{users, globals, uiaa, rooms, account_data, media, key_backups, transaction_ids, sending, appservice, pusher}}; use abstraction::KeyValueDatabaseEngine; use directories::ProjectDirs; use futures_util::{stream::FuturesUnordered, StreamExt}; @@ -253,7 +253,7 @@ impl KeyValueDatabase { let (admin_sender, admin_receiver) = mpsc::unbounded_channel(); let (sending_sender, sending_receiver) = mpsc::unbounded_channel(); - let db = Arc::new(TokioRwLock::from(Self { + let db = Self { _db: builder.clone(), userid_password: builder.open_tree("userid_password")?, userid_displayname: builder.open_tree("userid_displayname")?, @@ -345,10 +345,9 @@ impl KeyValueDatabase { senderkey_pusher: builder.open_tree("senderkey_pusher")?, global: builder.open_tree("global")?, server_signingkeys: builder.open_tree("server_signingkeys")?, - })); + }; // TODO: do this after constructing the db - let guard = db.read().await; // Matrix resource ownership is based on the server name; changing it // requires recreating the database from scratch. diff --git a/src/lib.rs b/src/lib.rs index 0d058df3..c6e65697 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ mod service; pub mod api; mod utils; -use std::cell::Cell; +use std::{cell::Cell, sync::RwLock}; pub use config::Config; pub use utils::error::{Error, Result}; @@ -22,13 +22,13 @@ pub use api::ruma_wrapper::{Ruma, RumaResponse}; use crate::database::KeyValueDatabase; -pub static SERVICES: Cell> = Cell::new(None); +pub static SERVICES: RwLock> = RwLock::new(None); enum ServicesEnum { Rocksdb(Services) } -pub fn services() -> Services { - SERVICES.get().unwrap() +pub fn services() -> Services { + SERVICES.read().unwrap() } diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs new file mode 100644 index 00000000..0f8e0bf5 --- /dev/null +++ b/src/service/account_data/data.rs @@ -0,0 +1,32 @@ +use std::collections::HashMap; + +use ruma::{UserId, RoomId, events::{RoomAccountDataEventType, AnyEphemeralRoomEvent}, serde::Raw}; +use serde::{Serialize, de::DeserializeOwned}; +use crate::Result; + +pub trait Data { + /// Places one event in the account data of the user and removes the previous entry. + fn update( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + event_type: RoomAccountDataEventType, + data: &T, + ) -> Result<()>; + + /// Searches the account data for a specific kind. + fn get( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + kind: RoomAccountDataEventType, + ) -> Result>; + + /// Returns all changes to the account data that happened after `since`. + fn changes_since( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + since: u64, + ) -> Result>>; +} diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs new file mode 100644 index 00000000..7a399223 --- /dev/null +++ b/src/service/account_data/mod.rs @@ -0,0 +1,158 @@ +mod data; + +pub use data::Data; + +use ruma::{ + api::client::{ + error::ErrorKind, + }, + events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, + serde::Raw, + signatures::CanonicalJsonValue, + DeviceId, RoomId, UserId, +}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{collections::HashMap, sync::Arc}; +use tracing::error; + +use crate::{service::*, services, utils, Error, Result}; + +pub struct Service { + db: D, +} + +impl Service { + /// Places one event in the account data of the user and removes the previous entry. + #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] + pub fn update( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + event_type: RoomAccountDataEventType, + data: &T, + ) -> Result<()> { + let mut prefix = room_id + .map(|r| r.to_string()) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xff); + + let mut roomuserdataid = prefix.clone(); + roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + roomuserdataid.push(0xff); + roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); + + let mut key = prefix; + key.extend_from_slice(event_type.to_string().as_bytes()); + + let json = serde_json::to_value(data).expect("all types here can be serialized"); // TODO: maybe add error handling + if json.get("type").is_none() || json.get("content").is_none() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Account data doesn't have all required fields.", + )); + } + + self.roomuserdataid_accountdata.insert( + &roomuserdataid, + &serde_json::to_vec(&json).expect("to_vec always works on json values"), + )?; + + let prev = self.roomusertype_roomuserdataid.get(&key)?; + + self.roomusertype_roomuserdataid + .insert(&key, &roomuserdataid)?; + + // Remove old entry + if let Some(prev) = prev { + self.roomuserdataid_accountdata.remove(&prev)?; + } + + Ok(()) + } + + /// Searches the account data for a specific kind. + #[tracing::instrument(skip(self, room_id, user_id, kind))] + pub fn get( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + kind: RoomAccountDataEventType, + ) -> Result> { + let mut key = room_id + .map(|r| r.to_string()) + .unwrap_or_default() + .as_bytes() + .to_vec(); + key.push(0xff); + key.extend_from_slice(user_id.as_bytes()); + key.push(0xff); + key.extend_from_slice(kind.to_string().as_bytes()); + + self.roomusertype_roomuserdataid + .get(&key)? + .and_then(|roomuserdataid| { + self.roomuserdataid_accountdata + .get(&roomuserdataid) + .transpose() + }) + .transpose()? + .map(|data| { + serde_json::from_slice(&data) + .map_err(|_| Error::bad_database("could not deserialize")) + }) + .transpose() + } + + /// Returns all changes to the account data that happened after `since`. + #[tracing::instrument(skip(self, room_id, user_id, since))] + pub fn changes_since( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + since: u64, + ) -> Result>> { + let mut userdata = HashMap::new(); + + let mut prefix = room_id + .map(|r| r.to_string()) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xff); + + // Skip the data that's exactly at since, because we sent that last time + let mut first_possible = prefix.clone(); + first_possible.extend_from_slice(&(since + 1).to_be_bytes()); + + for r in self + .roomuserdataid_accountdata + .iter_from(&first_possible, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(k, v)| { + Ok::<_, Error>(( + RoomAccountDataEventType::try_from( + utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else( + || Error::bad_database("RoomUserData ID in db is invalid."), + )?) + .map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?, + serde_json::from_slice::>(&v).map_err(|_| { + Error::bad_database("Database contains invalid account data.") + })?, + )) + }) + { + let (kind, data) = r?; + userdata.insert(kind, data); + } + + Ok(userdata) + } +} diff --git a/src/service/admin.rs b/src/service/admin.rs deleted file mode 100644 index ded0adb9..00000000 --- a/src/service/admin.rs +++ /dev/null @@ -1,1108 +0,0 @@ -use std::{ - collections::BTreeMap, - convert::{TryFrom, TryInto}, - sync::Arc, - time::Instant, -}; - -use clap::Parser; -use regex::Regex; -use ruma::{ - events::{ - room::{ - canonical_alias::RoomCanonicalAliasEventContent, - create::RoomCreateEventContent, - guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - join_rules::{JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, - message::RoomMessageEventContent, - name::RoomNameEventContent, - power_levels::RoomPowerLevelsEventContent, - topic::RoomTopicEventContent, - }, - RoomEventType, - }, - EventId, RoomAliasId, RoomId, RoomName, RoomVersionId, ServerName, UserId, -}; -use serde_json::value::to_raw_value; -use tokio::sync::{mpsc, MutexGuard, RwLock, RwLockReadGuard}; - -use crate::{services, Error, api::{server_server, client_server::AUTO_GEN_PASSWORD_LENGTH}, PduEvent, utils::{HtmlEscape, self}}; - -use super::pdu::PduBuilder; - -#[derive(Debug)] -pub enum AdminRoomEvent { - ProcessMessage(String), - SendMessage(RoomMessageEventContent), -} - -#[derive(Clone)] -pub struct Admin { - pub sender: mpsc::UnboundedSender, -} - -impl Admin { - pub fn start_handler( - &self, - mut receiver: mpsc::UnboundedReceiver, - ) { - tokio::spawn(async move { - // TODO: Use futures when we have long admin commands - //let mut futures = FuturesUnordered::new(); - - let conduit_user = UserId::parse(format!("@conduit:{}", services().globals.server_name())) - .expect("@conduit:server_name is valid"); - - let conduit_room = services() - .rooms - .id_from_alias( - format!("#admins:{}", services().globals.server_name()) - .as_str() - .try_into() - .expect("#admins:server_name is a valid room alias"), - ) - .expect("Database data for admin room alias must be valid") - .expect("Admin room must exist"); - - let send_message = |message: RoomMessageEventContent, - mutex_lock: &MutexGuard<'_, ()>| { - services() - .rooms - .build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMessage, - content: to_raw_value(&message) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &conduit_room, - mutex_lock, - ) - .unwrap(); - }; - - loop { - tokio::select! { - Some(event) = receiver.recv() => { - let message_content = match event { - AdminRoomEvent::SendMessage(content) => content, - AdminRoomEvent::ProcessMessage(room_message) => process_admin_message(room_message).await - }; - - let mutex_state = Arc::clone( - services().globals - .roomid_mutex_state - .write() - .unwrap() - .entry(conduit_room.clone()) - .or_default(), - ); - - let state_lock = mutex_state.lock().await; - - send_message(message_content, &state_lock); - - drop(state_lock); - } - } - } - }); - } - - pub fn process_message(&self, room_message: String) { - self.sender - .send(AdminRoomEvent::ProcessMessage(room_message)) - .unwrap(); - } - - pub fn send_message(&self, message_content: RoomMessageEventContent) { - self.sender - .send(AdminRoomEvent::SendMessage(message_content)) - .unwrap(); - } -} - -// Parse and process a message from the admin room -async fn process_admin_message(room_message: String) -> RoomMessageEventContent { - let mut lines = room_message.lines(); - let command_line = lines.next().expect("each string has at least one line"); - let body: Vec<_> = lines.collect(); - - let admin_command = match parse_admin_command(&command_line) { - Ok(command) => command, - Err(error) => { - let server_name = services().globals.server_name(); - let message = error - .to_string() - .replace("server.name", server_name.as_str()); - let html_message = usage_to_html(&message, server_name); - - return RoomMessageEventContent::text_html(message, html_message); - } - }; - - match process_admin_command(admin_command, body).await { - Ok(reply_message) => reply_message, - Err(error) => { - let markdown_message = format!( - "Encountered an error while handling the command:\n\ - ```\n{}\n```", - error, - ); - let html_message = format!( - "Encountered an error while handling the command:\n\ -
\n{}\n
", - error, - ); - - RoomMessageEventContent::text_html(markdown_message, html_message) - } - } -} - -// Parse chat messages from the admin room into an AdminCommand object -fn parse_admin_command(command_line: &str) -> std::result::Result { - // Note: argv[0] is `@conduit:servername:`, which is treated as the main command - let mut argv: Vec<_> = command_line.split_whitespace().collect(); - - // Replace `help command` with `command --help` - // Clap has a help subcommand, but it omits the long help description. - if argv.len() > 1 && argv[1] == "help" { - argv.remove(1); - argv.push("--help"); - } - - // Backwards compatibility with `register_appservice`-style commands - let command_with_dashes; - if argv.len() > 1 && argv[1].contains("_") { - command_with_dashes = argv[1].replace("_", "-"); - argv[1] = &command_with_dashes; - } - - AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) -} - -#[derive(Parser)] -#[clap(name = "@conduit:server.name:", version = env!("CARGO_PKG_VERSION"))] -enum AdminCommand { - #[clap(verbatim_doc_comment)] - /// Register an appservice using its registration YAML - /// - /// This command needs a YAML generated by an appservice (such as a bridge), - /// which must be provided in a Markdown code-block below the command. - /// - /// Registering a new bridge using the ID of an existing bridge will replace - /// the old one. - /// - /// [commandbody] - /// # ``` - /// # yaml content here - /// # ``` - RegisterAppservice, - - /// Unregister an appservice using its ID - /// - /// You can find the ID using the `list-appservices` command. - UnregisterAppservice { - /// The appservice to unregister - appservice_identifier: String, - }, - - /// List all the currently registered appservices - ListAppservices, - - /// List all rooms the server knows about - ListRooms, - - /// List users in the database - ListLocalUsers, - - /// List all rooms we are currently handling an incoming pdu from - IncomingFederation, - - /// Deactivate a user - /// - /// User will not be removed from all rooms by default. - /// Use --leave-rooms to force the user to leave all rooms - DeactivateUser { - #[clap(short, long)] - leave_rooms: bool, - user_id: Box, - }, - - #[clap(verbatim_doc_comment)] - /// Deactivate a list of users - /// - /// Recommended to use in conjunction with list-local-users. - /// - /// Users will not be removed from joined rooms by default. - /// Can be overridden with --leave-rooms flag. - /// Removing a mass amount of users from a room may cause a significant amount of leave events. - /// The time to leave rooms may depend significantly on joined rooms and servers. - /// - /// [commandbody] - /// # ``` - /// # User list here - /// # ``` - DeactivateAll { - #[clap(short, long)] - /// Remove users from their joined rooms - leave_rooms: bool, - #[clap(short, long)] - /// Also deactivate admin accounts - force: bool, - }, - - /// Get the auth_chain of a PDU - GetAuthChain { - /// An event ID (the $ character followed by the base64 reference hash) - event_id: Box, - }, - - #[clap(verbatim_doc_comment)] - /// Parse and print a PDU from a JSON - /// - /// The PDU event is only checked for validity and is not added to the - /// database. - /// - /// [commandbody] - /// # ``` - /// # PDU json content here - /// # ``` - ParsePdu, - - /// Retrieve and print a PDU by ID from the Conduit database - GetPdu { - /// An event ID (a $ followed by the base64 reference hash) - event_id: Box, - }, - - /// Print database memory usage statistics - DatabaseMemoryUsage, - - /// Show configuration values - ShowConfig, - - /// Reset user password - ResetPassword { - /// Username of the user for whom the password should be reset - username: String, - }, - - /// Create a new user - CreateUser { - /// Username of the new user - username: String, - /// Password of the new user, if unspecified one is generated - password: Option, - }, - - /// Disables incoming federation handling for a room. - DisableRoom { room_id: Box }, - /// Enables incoming federation handling for a room again. - EnableRoom { room_id: Box }, -} - -async fn process_admin_command( - command: AdminCommand, - body: Vec<&str>, -) -> Result { - let reply_message_content = match command { - AdminCommand::RegisterAppservice => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" { - let appservice_config = body[1..body.len() - 1].join("\n"); - let parsed_config = serde_yaml::from_str::(&appservice_config); - match parsed_config { - Ok(yaml) => match services().appservice.register_appservice(yaml) { - Ok(id) => RoomMessageEventContent::text_plain(format!( - "Appservice registered with ID: {}.", - id - )), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Failed to register appservice: {}", - e - )), - }, - Err(e) => RoomMessageEventContent::text_plain(format!( - "Could not parse appservice config: {}", - e - )), - } - } else { - RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", - ) - } - } - AdminCommand::UnregisterAppservice { - appservice_identifier, - } => match services().appservice.unregister_appservice(&appservice_identifier) { - Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Failed to unregister appservice: {}", - e - )), - }, - AdminCommand::ListAppservices => { - if let Ok(appservices) = services().appservice.iter_ids().map(|ids| ids.collect::>()) { - let count = appservices.len(); - let output = format!( - "Appservices ({}): {}", - count, - appservices - .into_iter() - .filter_map(|r| r.ok()) - .collect::>() - .join(", ") - ); - RoomMessageEventContent::text_plain(output) - } else { - RoomMessageEventContent::text_plain("Failed to get appservices.") - } - } - AdminCommand::ListRooms => { - let room_ids = services().rooms.iter_ids(); - let output = format!( - "Rooms:\n{}", - room_ids - .filter_map(|r| r.ok()) - .map(|id| id.to_string() - + "\tMembers: " - + &services() - .rooms - .room_joined_count(&id) - .ok() - .flatten() - .unwrap_or(0) - .to_string()) - .collect::>() - .join("\n") - ); - RoomMessageEventContent::text_plain(output) - } - AdminCommand::ListLocalUsers => match services().users.list_local_users() { - Ok(users) => { - let mut msg: String = format!("Found {} local user account(s):\n", users.len()); - msg += &users.join("\n"); - RoomMessageEventContent::text_plain(&msg) - } - Err(e) => RoomMessageEventContent::text_plain(e.to_string()), - }, - AdminCommand::IncomingFederation => { - let map = services().globals.roomid_federationhandletime.read().unwrap(); - let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); - - for (r, (e, i)) in map.iter() { - let elapsed = i.elapsed(); - msg += &format!( - "{} {}: {}m{}s\n", - r, - e, - elapsed.as_secs() / 60, - elapsed.as_secs() % 60 - ); - } - RoomMessageEventContent::text_plain(&msg) - } - AdminCommand::GetAuthChain { event_id } => { - let event_id = Arc::::from(event_id); - if let Some(event) = services().rooms.get_pdu_json(&event_id)? { - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - - let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| { - Error::bad_database("Invalid room id field in event in database") - })?; - let start = Instant::now(); - let count = server_server::get_auth_chain(room_id, vec![event_id]) - .await? - .count(); - let elapsed = start.elapsed(); - RoomMessageEventContent::text_plain(format!( - "Loaded auth chain with length {} in {:?}", - count, elapsed - )) - } else { - RoomMessageEventContent::text_plain("Event not found.") - } - } - AdminCommand::ParsePdu => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" { - let string = body[1..body.len() - 1].join("\n"); - match serde_json::from_str(&string) { - Ok(value) => { - match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { - Ok(hash) => { - let event_id = EventId::parse(format!("${}", hash)); - - match serde_json::from_value::( - serde_json::to_value(value).expect("value is json"), - ) { - Ok(pdu) => RoomMessageEventContent::text_plain(format!( - "EventId: {:?}\n{:#?}", - event_id, pdu - )), - Err(e) => RoomMessageEventContent::text_plain(format!( - "EventId: {:?}\nCould not parse event: {}", - event_id, e - )), - } - } - Err(e) => RoomMessageEventContent::text_plain(format!( - "Could not parse PDU JSON: {:?}", - e - )), - } - } - Err(e) => RoomMessageEventContent::text_plain(format!( - "Invalid json in command body: {}", - e - )), - } - } else { - RoomMessageEventContent::text_plain("Expected code block in command body.") - } - } - AdminCommand::GetPdu { event_id } => { - let mut outlier = false; - let mut pdu_json = services().rooms.get_non_outlier_pdu_json(&event_id)?; - if pdu_json.is_none() { - outlier = true; - pdu_json = services().rooms.get_pdu_json(&event_id)?; - } - match pdu_json { - Some(json) => { - let json_text = - serde_json::to_string_pretty(&json).expect("canonical json is valid json"); - RoomMessageEventContent::text_html( - format!( - "{}\n```json\n{}\n```", - if outlier { - "PDU is outlier" - } else { - "PDU was accepted" - }, - json_text - ), - format!( - "

{}

\n
{}\n
\n", - if outlier { - "PDU is outlier" - } else { - "PDU was accepted" - }, - HtmlEscape(&json_text) - ), - ) - } - None => RoomMessageEventContent::text_plain("PDU not found."), - } - } - AdminCommand::DatabaseMemoryUsage => match services()._db.memory_usage() { - Ok(response) => RoomMessageEventContent::text_plain(response), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Failed to get database memory usage: {}", - e - )), - }, - AdminCommand::ShowConfig => { - // Construct and send the response - RoomMessageEventContent::text_plain(format!("{}", services().globals.config)) - } - AdminCommand::ResetPassword { username } => { - let user_id = match UserId::parse_with_server_name( - username.as_str().to_lowercase(), - services().globals.server_name(), - ) { - Ok(id) => id, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "The supplied username is not a valid username: {}", - e - ))) - } - }; - - // Check if the specified user is valid - if !services().users.exists(&user_id)? - || services().users.is_deactivated(&user_id)? - || user_id - == UserId::parse_with_server_name("conduit", services().globals.server_name()) - .expect("conduit user exists") - { - return Ok(RoomMessageEventContent::text_plain( - "The specified user does not exist or is deactivated!", - )); - } - - let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); - - match services().users.set_password(&user_id, Some(new_password.as_str())) { - Ok(()) => RoomMessageEventContent::text_plain(format!( - "Successfully reset the password for user {}: {}", - user_id, new_password - )), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Couldn't reset the password for user {}: {}", - user_id, e - )), - } - } - AdminCommand::CreateUser { username, password } => { - let password = password.unwrap_or(utils::random_string(AUTO_GEN_PASSWORD_LENGTH)); - // Validate user id - let user_id = match UserId::parse_with_server_name( - username.as_str().to_lowercase(), - services().globals.server_name(), - ) { - Ok(id) => id, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "The supplied username is not a valid username: {}", - e - ))) - } - }; - if user_id.is_historical() { - return Ok(RoomMessageEventContent::text_plain(format!( - "userid {user_id} is not allowed due to historical" - ))); - } - if services().users.exists(&user_id)? { - return Ok(RoomMessageEventContent::text_plain(format!( - "userid {user_id} already exists" - ))); - } - // Create user - services().users.create(&user_id, Some(password.as_str()))?; - - // Default to pretty displayname - let displayname = format!("{} ⚡️", user_id.localpart()); - services().users - .set_displayname(&user_id, Some(displayname.clone()))?; - - // Initial account data - services().account_data.update( - None, - &user_id, - ruma::events::GlobalAccountDataEventType::PushRules - .to_string() - .into(), - &ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: ruma::push::Ruleset::server_default(&user_id), - }, - }, - )?; - - // we dont add a device since we're not the user, just the creator - - // Inhibit login does not work for guests - RoomMessageEventContent::text_plain(format!( - "Created user with user_id: {user_id} and password: {password}" - )) - } - AdminCommand::DisableRoom { room_id } => { - services().rooms.disabledroomids.insert(room_id.as_bytes(), &[])?; - RoomMessageEventContent::text_plain("Room disabled.") - } - AdminCommand::EnableRoom { room_id } => { - services().rooms.disabledroomids.remove(room_id.as_bytes())?; - RoomMessageEventContent::text_plain("Room enabled.") - } - AdminCommand::DeactivateUser { - leave_rooms, - user_id, - } => { - let user_id = Arc::::from(user_id); - if services().users.exists(&user_id)? { - RoomMessageEventContent::text_plain(format!( - "Making {} leave all rooms before deactivation...", - user_id - )); - - services().users.deactivate_account(&user_id)?; - - if leave_rooms { - services().rooms.leave_all_rooms(&user_id).await?; - } - - RoomMessageEventContent::text_plain(format!( - "User {} has been deactivated", - user_id - )) - } else { - RoomMessageEventContent::text_plain(format!( - "User {} doesn't exist on this server", - user_id - )) - } - } - AdminCommand::DeactivateAll { leave_rooms, force } => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" { - let usernames = body.clone().drain(1..body.len() - 1).collect::>(); - - let mut user_ids: Vec<&UserId> = Vec::new(); - - for &username in &usernames { - match <&UserId>::try_from(username) { - Ok(user_id) => user_ids.push(user_id), - Err(_) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "{} is not a valid username", - username - ))) - } - } - } - - let mut deactivation_count = 0; - let mut admins = Vec::new(); - - if !force { - user_ids.retain(|&user_id| { - match services().users.is_admin(user_id) { - Ok(is_admin) => match is_admin { - true => { - admins.push(user_id.localpart()); - false - } - false => true, - }, - Err(_) => false, - } - }) - } - - for &user_id in &user_ids { - match services().users.deactivate_account(user_id) { - Ok(_) => deactivation_count += 1, - Err(_) => {} - } - } - - if leave_rooms { - for &user_id in &user_ids { - let _ = services().rooms.leave_all_rooms(user_id).await; - } - } - - if admins.is_empty() { - RoomMessageEventContent::text_plain(format!( - "Deactivated {} accounts.", - deactivation_count - )) - } else { - RoomMessageEventContent::text_plain(format!("Deactivated {} accounts.\nSkipped admin accounts: {:?}. Use --force to deactivate admin accounts", deactivation_count, admins.join(", "))) - } - } else { - RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", - ) - } - } - }; - - Ok(reply_message_content) -} - -// Utility to turn clap's `--help` text to HTML. -fn usage_to_html(text: &str, server_name: &ServerName) -> String { - // Replace `@conduit:servername:-subcmdname` with `@conduit:servername: subcmdname` - let text = text.replace( - &format!("@conduit:{}:-", server_name), - &format!("@conduit:{}: ", server_name), - ); - - // For the conduit admin room, subcommands become main commands - let text = text.replace("SUBCOMMAND", "COMMAND"); - let text = text.replace("subcommand", "command"); - - // Escape option names (e.g. ``) since they look like HTML tags - let text = text.replace("<", "<").replace(">", ">"); - - // Italicize the first line (command name and version text) - let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); - let text = re.replace_all(&text, "$1\n"); - - // Unmerge wrapped lines - let text = text.replace("\n ", " "); - - // Wrap option names in backticks. The lines look like: - // -V, --version Prints version information - // And are converted to: - // -V, --version: Prints version information - // (?m) enables multi-line mode for ^ and $ - let re = Regex::new("(?m)^ (([a-zA-Z_&;-]+(, )?)+) +(.*)$") - .expect("Regex compilation should not fail"); - let text = re.replace_all(&text, "$1: $4"); - - // Look for a `[commandbody]` tag. If it exists, use all lines below it that - // start with a `#` in the USAGE section. - let mut text_lines: Vec<&str> = text.lines().collect(); - let mut command_body = String::new(); - - if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") { - text_lines.remove(line_index); - - while text_lines - .get(line_index) - .map(|line| line.starts_with("#")) - .unwrap_or(false) - { - command_body += if text_lines[line_index].starts_with("# ") { - &text_lines[line_index][2..] - } else { - &text_lines[line_index][1..] - }; - command_body += "[nobr]\n"; - text_lines.remove(line_index); - } - } - - let text = text_lines.join("\n"); - - // Improve the usage section - let text = if command_body.is_empty() { - // Wrap the usage line in code tags - let re = Regex::new("(?m)^USAGE:\n (@conduit:.*)$") - .expect("Regex compilation should not fail"); - re.replace_all(&text, "USAGE:\n$1").to_string() - } else { - // Wrap the usage line in a code block, and add a yaml block example - // This makes the usage of e.g. `register-appservice` more accurate - let re = - Regex::new("(?m)^USAGE:\n (.*?)\n\n").expect("Regex compilation should not fail"); - re.replace_all(&text, "USAGE:\n
$1[nobr]\n[commandbodyblock]
") - .replace("[commandbodyblock]", &command_body) - }; - - // Add HTML line-breaks - let text = text - .replace("\n\n\n", "\n\n") - .replace("\n", "
\n") - .replace("[nobr]
", ""); - - text -} - -/// Create the admin room. -/// -/// Users in this room are considered admins by conduit, and the room can be -/// used to issue admin commands by talking to the server user inside it. -pub(crate) async fn create_admin_room() -> Result<()> { - let room_id = RoomId::new(services().globals.server_name()); - - services().rooms.get_or_create_shortroomid(&room_id)?; - - let mutex_state = Arc::clone( - services().globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Create a user for the server - let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) - .expect("@conduit:server_name is valid"); - - services().users.create(&conduit_user, None)?; - - let mut content = RoomCreateEventContent::new(conduit_user.clone()); - content.federate = true; - content.predecessor = None; - content.room_version = RoomVersionId::V6; - - // 1. The room create event - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomCreate, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - // 2. Make conduit bot join - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(conduit_user.to_string()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - // 3. Power levels - let mut users = BTreeMap::new(); - users.insert(conduit_user.clone(), 100.into()); - - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - // 4.1 Join Rules - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomJoinRules, - content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - // 4.2 History Visibility - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomHistoryVisibility, - content: to_raw_value(&RoomHistoryVisibilityEventContent::new( - HistoryVisibility::Shared, - )) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - // 4.3 Guest Access - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomGuestAccess, - content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - // 5. Events implied by name and topic - let room_name = RoomName::parse(format!("{} Admin Room", services().globals.server_name())) - .expect("Room name is valid"); - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomName, - content: to_raw_value(&RoomNameEventContent::new(Some(room_name))) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomTopic, - content: to_raw_value(&RoomTopicEventContent { - topic: format!("Manage {}", services().globals.server_name()), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - // 6. Room alias - let alias: Box = format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomCanonicalAlias, - content: to_raw_value(&RoomCanonicalAliasEventContent { - alias: Some(alias.clone()), - alt_aliases: Vec::new(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - services().rooms.set_alias(&alias, Some(&room_id))?; - - Ok(()) -} - -/// Invite the user to the conduit admin room. -/// -/// In conduit, this is equivalent to granting admin privileges. -pub(crate) async fn make_user_admin( - user_id: &UserId, - displayname: String, -) -> Result<()> { - let admin_room_alias: Box = format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - let room_id = services() - .rooms - .id_from_alias(&admin_room_alias)? - .expect("Admin room must exist"); - - let mutex_state = Arc::clone( - services().globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Use the server user to grant the new admin's power level - let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) - .expect("@conduit:server_name is valid"); - - // Invite and join the real user - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: Some(displayname), - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - &user_id, - &room_id, - &state_lock, - )?; - - // Set power level - let mut users = BTreeMap::new(); - users.insert(conduit_user.to_owned(), 100.into()); - users.insert(user_id.to_owned(), 100.into()); - - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - // Send welcome message - services().rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMessage, - content: to_raw_value(&RoomMessageEventContent::text_html( - format!("## Thank you for trying out Conduit!\n\nConduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Website: https://conduit.rs\n> Git and Documentation: https://gitlab.com/famedly/conduit\n> Report issues: https://gitlab.com/famedly/conduit/-/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nConduit room (Ask questions and get notified on updates):\n`/join #conduit:fachschaften.org`\n\nConduit lounge (Off-topic, only Conduit users are allowed to join)\n`/join #conduit-lounge:conduit.rs`", services().globals.server_name()).to_owned(), - format!("

Thank you for trying out Conduit!

\n

Conduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.

\n

Helpful links:

\n
\n

Website: https://conduit.rs
Git and Documentation: https://gitlab.com/famedly/conduit
Report issues: https://gitlab.com/famedly/conduit/-/issues

\n
\n

For a list of available commands, send the following message in this room: @conduit:{}: --help

\n

Here are some rooms you can join (by typing the command):

\n

Conduit room (Ask questions and get notified on updates):
/join #conduit:fachschaften.org

\n

Conduit lounge (Off-topic, only Conduit users are allowed to join)
/join #conduit-lounge:conduit.rs

\n", services().globals.server_name()).to_owned(), - )) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - )?; - - Ok(()) -} diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs new file mode 100644 index 00000000..dad4ceba --- /dev/null +++ b/src/service/admin/mod.rs @@ -0,0 +1,1111 @@ +use std::{ + collections::BTreeMap, + convert::{TryFrom, TryInto}, + sync::Arc, + time::Instant, +}; + +use clap::Parser; +use regex::Regex; +use ruma::{ + events::{ + room::{ + canonical_alias::RoomCanonicalAliasEventContent, + create::RoomCreateEventContent, + guest_access::{GuestAccess, RoomGuestAccessEventContent}, + history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + message::RoomMessageEventContent, + name::RoomNameEventContent, + power_levels::RoomPowerLevelsEventContent, + topic::RoomTopicEventContent, + }, + RoomEventType, + }, + EventId, RoomAliasId, RoomId, RoomName, RoomVersionId, ServerName, UserId, +}; +use serde_json::value::to_raw_value; +use tokio::sync::{mpsc, MutexGuard, RwLock, RwLockReadGuard}; + +use crate::{Result, services, Error, api::{server_server, client_server::AUTO_GEN_PASSWORD_LENGTH}, PduEvent, utils::{HtmlEscape, self}}; + +use super::pdu::PduBuilder; + +#[derive(Parser)] +#[clap(name = "@conduit:server.name:", version = env!("CARGO_PKG_VERSION"))] +enum AdminCommand { + #[clap(verbatim_doc_comment)] + /// Register an appservice using its registration YAML + /// + /// This command needs a YAML generated by an appservice (such as a bridge), + /// which must be provided in a Markdown code-block below the command. + /// + /// Registering a new bridge using the ID of an existing bridge will replace + /// the old one. + /// + /// [commandbody] + /// # ``` + /// # yaml content here + /// # ``` + RegisterAppservice, + + /// Unregister an appservice using its ID + /// + /// You can find the ID using the `list-appservices` command. + UnregisterAppservice { + /// The appservice to unregister + appservice_identifier: String, + }, + + /// List all the currently registered appservices + ListAppservices, + + /// List all rooms the server knows about + ListRooms, + + /// List users in the database + ListLocalUsers, + + /// List all rooms we are currently handling an incoming pdu from + IncomingFederation, + + /// Deactivate a user + /// + /// User will not be removed from all rooms by default. + /// Use --leave-rooms to force the user to leave all rooms + DeactivateUser { + #[clap(short, long)] + leave_rooms: bool, + user_id: Box, + }, + + #[clap(verbatim_doc_comment)] + /// Deactivate a list of users + /// + /// Recommended to use in conjunction with list-local-users. + /// + /// Users will not be removed from joined rooms by default. + /// Can be overridden with --leave-rooms flag. + /// Removing a mass amount of users from a room may cause a significant amount of leave events. + /// The time to leave rooms may depend significantly on joined rooms and servers. + /// + /// [commandbody] + /// # ``` + /// # User list here + /// # ``` + DeactivateAll { + #[clap(short, long)] + /// Remove users from their joined rooms + leave_rooms: bool, + #[clap(short, long)] + /// Also deactivate admin accounts + force: bool, + }, + + /// Get the auth_chain of a PDU + GetAuthChain { + /// An event ID (the $ character followed by the base64 reference hash) + event_id: Box, + }, + + #[clap(verbatim_doc_comment)] + /// Parse and print a PDU from a JSON + /// + /// The PDU event is only checked for validity and is not added to the + /// database. + /// + /// [commandbody] + /// # ``` + /// # PDU json content here + /// # ``` + ParsePdu, + + /// Retrieve and print a PDU by ID from the Conduit database + GetPdu { + /// An event ID (a $ followed by the base64 reference hash) + event_id: Box, + }, + + /// Print database memory usage statistics + DatabaseMemoryUsage, + + /// Show configuration values + ShowConfig, + + /// Reset user password + ResetPassword { + /// Username of the user for whom the password should be reset + username: String, + }, + + /// Create a new user + CreateUser { + /// Username of the new user + username: String, + /// Password of the new user, if unspecified one is generated + password: Option, + }, + + /// Disables incoming federation handling for a room. + DisableRoom { room_id: Box }, + /// Enables incoming federation handling for a room again. + EnableRoom { room_id: Box }, +} + + +#[derive(Debug)] +pub enum AdminRoomEvent { + ProcessMessage(String), + SendMessage(RoomMessageEventContent), +} + +#[derive(Clone)] +pub struct Service { + pub sender: mpsc::UnboundedSender, +} + +impl Service { + pub fn start_handler( + &self, + mut receiver: mpsc::UnboundedReceiver, + ) { + tokio::spawn(async move { + // TODO: Use futures when we have long admin commands + //let mut futures = FuturesUnordered::new(); + + let conduit_user = UserId::parse(format!("@conduit:{}", services().globals.server_name())) + .expect("@conduit:server_name is valid"); + + let conduit_room = services() + .rooms + .id_from_alias( + format!("#admins:{}", services().globals.server_name()) + .as_str() + .try_into() + .expect("#admins:server_name is a valid room alias"), + ) + .expect("Database data for admin room alias must be valid") + .expect("Admin room must exist"); + + let send_message = |message: RoomMessageEventContent, + mutex_lock: &MutexGuard<'_, ()>| { + services() + .rooms + .build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMessage, + content: to_raw_value(&message) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &conduit_room, + mutex_lock, + ) + .unwrap(); + }; + + loop { + tokio::select! { + Some(event) = receiver.recv() => { + let message_content = match event { + AdminRoomEvent::SendMessage(content) => content, + AdminRoomEvent::ProcessMessage(room_message) => process_admin_message(room_message).await + }; + + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(conduit_room.clone()) + .or_default(), + ); + + let state_lock = mutex_state.lock().await; + + send_message(message_content, &state_lock); + + drop(state_lock); + } + } + } + }); + } + + pub fn process_message(&self, room_message: String) { + self.sender + .send(AdminRoomEvent::ProcessMessage(room_message)) + .unwrap(); + } + + pub fn send_message(&self, message_content: RoomMessageEventContent) { + self.sender + .send(AdminRoomEvent::SendMessage(message_content)) + .unwrap(); + } + + // Parse and process a message from the admin room + async fn process_admin_message(&self, room_message: String) -> RoomMessageEventContent { + let mut lines = room_message.lines(); + let command_line = lines.next().expect("each string has at least one line"); + let body: Vec<_> = lines.collect(); + + let admin_command = match parse_admin_command(&command_line) { + Ok(command) => command, + Err(error) => { + let server_name = services().globals.server_name(); + let message = error + .to_string() + .replace("server.name", server_name.as_str()); + let html_message = usage_to_html(&message, server_name); + + return RoomMessageEventContent::text_html(message, html_message); + } + }; + + match process_admin_command(admin_command, body).await { + Ok(reply_message) => reply_message, + Err(error) => { + let markdown_message = format!( + "Encountered an error while handling the command:\n\ + ```\n{}\n```", + error, + ); + let html_message = format!( + "Encountered an error while handling the command:\n\ +
\n{}\n
", + error, + ); + + RoomMessageEventContent::text_html(markdown_message, html_message) + } + } + } + + // Parse chat messages from the admin room into an AdminCommand object + fn parse_admin_command(&self, command_line: &str) -> std::result::Result { + // Note: argv[0] is `@conduit:servername:`, which is treated as the main command + let mut argv: Vec<_> = command_line.split_whitespace().collect(); + + // Replace `help command` with `command --help` + // Clap has a help subcommand, but it omits the long help description. + if argv.len() > 1 && argv[1] == "help" { + argv.remove(1); + argv.push("--help"); + } + + // Backwards compatibility with `register_appservice`-style commands + let command_with_dashes; + if argv.len() > 1 && argv[1].contains("_") { + command_with_dashes = argv[1].replace("_", "-"); + argv[1] = &command_with_dashes; + } + + AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) + } + + async fn process_admin_command( + &self, + command: AdminCommand, + body: Vec<&str>, + ) -> Result { + let reply_message_content = match command { + AdminCommand::RegisterAppservice => { + if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" { + let appservice_config = body[1..body.len() - 1].join("\n"); + let parsed_config = serde_yaml::from_str::(&appservice_config); + match parsed_config { + Ok(yaml) => match services().appservice.register_appservice(yaml) { + Ok(id) => RoomMessageEventContent::text_plain(format!( + "Appservice registered with ID: {}.", + id + )), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Failed to register appservice: {}", + e + )), + }, + Err(e) => RoomMessageEventContent::text_plain(format!( + "Could not parse appservice config: {}", + e + )), + } + } else { + RoomMessageEventContent::text_plain( + "Expected code block in command body. Add --help for details.", + ) + } + } + AdminCommand::UnregisterAppservice { + appservice_identifier, + } => match services().appservice.unregister_appservice(&appservice_identifier) { + Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Failed to unregister appservice: {}", + e + )), + }, + AdminCommand::ListAppservices => { + if let Ok(appservices) = services().appservice.iter_ids().map(|ids| ids.collect::>()) { + let count = appservices.len(); + let output = format!( + "Appservices ({}): {}", + count, + appservices + .into_iter() + .filter_map(|r| r.ok()) + .collect::>() + .join(", ") + ); + RoomMessageEventContent::text_plain(output) + } else { + RoomMessageEventContent::text_plain("Failed to get appservices.") + } + } + AdminCommand::ListRooms => { + let room_ids = services().rooms.iter_ids(); + let output = format!( + "Rooms:\n{}", + room_ids + .filter_map(|r| r.ok()) + .map(|id| id.to_string() + + "\tMembers: " + + &services() + .rooms + .room_joined_count(&id) + .ok() + .flatten() + .unwrap_or(0) + .to_string()) + .collect::>() + .join("\n") + ); + RoomMessageEventContent::text_plain(output) + } + AdminCommand::ListLocalUsers => match services().users.list_local_users() { + Ok(users) => { + let mut msg: String = format!("Found {} local user account(s):\n", users.len()); + msg += &users.join("\n"); + RoomMessageEventContent::text_plain(&msg) + } + Err(e) => RoomMessageEventContent::text_plain(e.to_string()), + }, + AdminCommand::IncomingFederation => { + let map = services().globals.roomid_federationhandletime.read().unwrap(); + let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); + + for (r, (e, i)) in map.iter() { + let elapsed = i.elapsed(); + msg += &format!( + "{} {}: {}m{}s\n", + r, + e, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); + } + RoomMessageEventContent::text_plain(&msg) + } + AdminCommand::GetAuthChain { event_id } => { + let event_id = Arc::::from(event_id); + if let Some(event) = services().rooms.get_pdu_json(&event_id)? { + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + + let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| { + Error::bad_database("Invalid room id field in event in database") + })?; + let start = Instant::now(); + let count = server_server::get_auth_chain(room_id, vec![event_id]) + .await? + .count(); + let elapsed = start.elapsed(); + RoomMessageEventContent::text_plain(format!( + "Loaded auth chain with length {} in {:?}", + count, elapsed + )) + } else { + RoomMessageEventContent::text_plain("Event not found.") + } + } + AdminCommand::ParsePdu => { + if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" { + let string = body[1..body.len() - 1].join("\n"); + match serde_json::from_str(&string) { + Ok(value) => { + match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { + Ok(hash) => { + let event_id = EventId::parse(format!("${}", hash)); + + match serde_json::from_value::( + serde_json::to_value(value).expect("value is json"), + ) { + Ok(pdu) => RoomMessageEventContent::text_plain(format!( + "EventId: {:?}\n{:#?}", + event_id, pdu + )), + Err(e) => RoomMessageEventContent::text_plain(format!( + "EventId: {:?}\nCould not parse event: {}", + event_id, e + )), + } + } + Err(e) => RoomMessageEventContent::text_plain(format!( + "Could not parse PDU JSON: {:?}", + e + )), + } + } + Err(e) => RoomMessageEventContent::text_plain(format!( + "Invalid json in command body: {}", + e + )), + } + } else { + RoomMessageEventContent::text_plain("Expected code block in command body.") + } + } + AdminCommand::GetPdu { event_id } => { + let mut outlier = false; + let mut pdu_json = services().rooms.get_non_outlier_pdu_json(&event_id)?; + if pdu_json.is_none() { + outlier = true; + pdu_json = services().rooms.get_pdu_json(&event_id)?; + } + match pdu_json { + Some(json) => { + let json_text = + serde_json::to_string_pretty(&json).expect("canonical json is valid json"); + RoomMessageEventContent::text_html( + format!( + "{}\n```json\n{}\n```", + if outlier { + "PDU is outlier" + } else { + "PDU was accepted" + }, + json_text + ), + format!( + "

{}

\n
{}\n
\n", + if outlier { + "PDU is outlier" + } else { + "PDU was accepted" + }, + HtmlEscape(&json_text) + ), + ) + } + None => RoomMessageEventContent::text_plain("PDU not found."), + } + } + AdminCommand::DatabaseMemoryUsage => match services()._db.memory_usage() { + Ok(response) => RoomMessageEventContent::text_plain(response), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Failed to get database memory usage: {}", + e + )), + }, + AdminCommand::ShowConfig => { + // Construct and send the response + RoomMessageEventContent::text_plain(format!("{}", services().globals.config)) + } + AdminCommand::ResetPassword { username } => { + let user_id = match UserId::parse_with_server_name( + username.as_str().to_lowercase(), + services().globals.server_name(), + ) { + Ok(id) => id, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "The supplied username is not a valid username: {}", + e + ))) + } + }; + + // Check if the specified user is valid + if !services().users.exists(&user_id)? + || services().users.is_deactivated(&user_id)? + || user_id + == UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("conduit user exists") + { + return Ok(RoomMessageEventContent::text_plain( + "The specified user does not exist or is deactivated!", + )); + } + + let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); + + match services().users.set_password(&user_id, Some(new_password.as_str())) { + Ok(()) => RoomMessageEventContent::text_plain(format!( + "Successfully reset the password for user {}: {}", + user_id, new_password + )), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Couldn't reset the password for user {}: {}", + user_id, e + )), + } + } + AdminCommand::CreateUser { username, password } => { + let password = password.unwrap_or(utils::random_string(AUTO_GEN_PASSWORD_LENGTH)); + // Validate user id + let user_id = match UserId::parse_with_server_name( + username.as_str().to_lowercase(), + services().globals.server_name(), + ) { + Ok(id) => id, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "The supplied username is not a valid username: {}", + e + ))) + } + }; + if user_id.is_historical() { + return Ok(RoomMessageEventContent::text_plain(format!( + "userid {user_id} is not allowed due to historical" + ))); + } + if services().users.exists(&user_id)? { + return Ok(RoomMessageEventContent::text_plain(format!( + "userid {user_id} already exists" + ))); + } + // Create user + services().users.create(&user_id, Some(password.as_str()))?; + + // Default to pretty displayname + let displayname = format!("{} ⚡️", user_id.localpart()); + services().users + .set_displayname(&user_id, Some(displayname.clone()))?; + + // Initial account data + services().account_data.update( + None, + &user_id, + ruma::events::GlobalAccountDataEventType::PushRules + .to_string() + .into(), + &ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: ruma::push::Ruleset::server_default(&user_id), + }, + }, + )?; + + // we dont add a device since we're not the user, just the creator + + // Inhibit login does not work for guests + RoomMessageEventContent::text_plain(format!( + "Created user with user_id: {user_id} and password: {password}" + )) + } + AdminCommand::DisableRoom { room_id } => { + services().rooms.disabledroomids.insert(room_id.as_bytes(), &[])?; + RoomMessageEventContent::text_plain("Room disabled.") + } + AdminCommand::EnableRoom { room_id } => { + services().rooms.disabledroomids.remove(room_id.as_bytes())?; + RoomMessageEventContent::text_plain("Room enabled.") + } + AdminCommand::DeactivateUser { + leave_rooms, + user_id, + } => { + let user_id = Arc::::from(user_id); + if services().users.exists(&user_id)? { + RoomMessageEventContent::text_plain(format!( + "Making {} leave all rooms before deactivation...", + user_id + )); + + services().users.deactivate_account(&user_id)?; + + if leave_rooms { + services().rooms.leave_all_rooms(&user_id).await?; + } + + RoomMessageEventContent::text_plain(format!( + "User {} has been deactivated", + user_id + )) + } else { + RoomMessageEventContent::text_plain(format!( + "User {} doesn't exist on this server", + user_id + )) + } + } + AdminCommand::DeactivateAll { leave_rooms, force } => { + if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" { + let usernames = body.clone().drain(1..body.len() - 1).collect::>(); + + let mut user_ids: Vec<&UserId> = Vec::new(); + + for &username in &usernames { + match <&UserId>::try_from(username) { + Ok(user_id) => user_ids.push(user_id), + Err(_) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "{} is not a valid username", + username + ))) + } + } + } + + let mut deactivation_count = 0; + let mut admins = Vec::new(); + + if !force { + user_ids.retain(|&user_id| { + match services().users.is_admin(user_id) { + Ok(is_admin) => match is_admin { + true => { + admins.push(user_id.localpart()); + false + } + false => true, + }, + Err(_) => false, + } + }) + } + + for &user_id in &user_ids { + match services().users.deactivate_account(user_id) { + Ok(_) => deactivation_count += 1, + Err(_) => {} + } + } + + if leave_rooms { + for &user_id in &user_ids { + let _ = services().rooms.leave_all_rooms(user_id).await; + } + } + + if admins.is_empty() { + RoomMessageEventContent::text_plain(format!( + "Deactivated {} accounts.", + deactivation_count + )) + } else { + RoomMessageEventContent::text_plain(format!("Deactivated {} accounts.\nSkipped admin accounts: {:?}. Use --force to deactivate admin accounts", deactivation_count, admins.join(", "))) + } + } else { + RoomMessageEventContent::text_plain( + "Expected code block in command body. Add --help for details.", + ) + } + } + }; + + Ok(reply_message_content) + } + + // Utility to turn clap's `--help` text to HTML. + fn usage_to_html(&self, text: &str, server_name: &ServerName) -> String { + // Replace `@conduit:servername:-subcmdname` with `@conduit:servername: subcmdname` + let text = text.replace( + &format!("@conduit:{}:-", server_name), + &format!("@conduit:{}: ", server_name), + ); + + // For the conduit admin room, subcommands become main commands + let text = text.replace("SUBCOMMAND", "COMMAND"); + let text = text.replace("subcommand", "command"); + + // Escape option names (e.g. ``) since they look like HTML tags + let text = text.replace("<", "<").replace(">", ">"); + + // Italicize the first line (command name and version text) + let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); + let text = re.replace_all(&text, "$1\n"); + + // Unmerge wrapped lines + let text = text.replace("\n ", " "); + + // Wrap option names in backticks. The lines look like: + // -V, --version Prints version information + // And are converted to: + // -V, --version: Prints version information + // (?m) enables multi-line mode for ^ and $ + let re = Regex::new("(?m)^ (([a-zA-Z_&;-]+(, )?)+) +(.*)$") + .expect("Regex compilation should not fail"); + let text = re.replace_all(&text, "$1: $4"); + + // Look for a `[commandbody]` tag. If it exists, use all lines below it that + // start with a `#` in the USAGE section. + let mut text_lines: Vec<&str> = text.lines().collect(); + let mut command_body = String::new(); + + if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") { + text_lines.remove(line_index); + + while text_lines + .get(line_index) + .map(|line| line.starts_with("#")) + .unwrap_or(false) + { + command_body += if text_lines[line_index].starts_with("# ") { + &text_lines[line_index][2..] + } else { + &text_lines[line_index][1..] + }; + command_body += "[nobr]\n"; + text_lines.remove(line_index); + } + } + + let text = text_lines.join("\n"); + + // Improve the usage section + let text = if command_body.is_empty() { + // Wrap the usage line in code tags + let re = Regex::new("(?m)^USAGE:\n (@conduit:.*)$") + .expect("Regex compilation should not fail"); + re.replace_all(&text, "USAGE:\n$1").to_string() + } else { + // Wrap the usage line in a code block, and add a yaml block example + // This makes the usage of e.g. `register-appservice` more accurate + let re = + Regex::new("(?m)^USAGE:\n (.*?)\n\n").expect("Regex compilation should not fail"); + re.replace_all(&text, "USAGE:\n
$1[nobr]\n[commandbodyblock]
") + .replace("[commandbodyblock]", &command_body) + }; + + // Add HTML line-breaks + let text = text + .replace("\n\n\n", "\n\n") + .replace("\n", "
\n") + .replace("[nobr]
", ""); + + text + } + + /// Create the admin room. + /// + /// Users in this room are considered admins by conduit, and the room can be + /// used to issue admin commands by talking to the server user inside it. + pub(crate) async fn create_admin_room(&self) -> Result<()> { + let room_id = RoomId::new(services().globals.server_name()); + + services().rooms.get_or_create_shortroomid(&room_id)?; + + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Create a user for the server + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is valid"); + + services().users.create(&conduit_user, None)?; + + let mut content = RoomCreateEventContent::new(conduit_user.clone()); + content.federate = true; + content.predecessor = None; + content.room_version = RoomVersionId::V6; + + // 1. The room create event + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomCreate, + content: to_raw_value(&content).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 2. Make conduit bot join + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(conduit_user.to_string()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 3. Power levels + let mut users = BTreeMap::new(); + users.insert(conduit_user.clone(), 100.into()); + + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomPowerLevels, + content: to_raw_value(&RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 4.1 Join Rules + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomJoinRules, + content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 4.2 History Visibility + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomHistoryVisibility, + content: to_raw_value(&RoomHistoryVisibilityEventContent::new( + HistoryVisibility::Shared, + )) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 4.3 Guest Access + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomGuestAccess, + content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden)) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 5. Events implied by name and topic + let room_name = RoomName::parse(format!("{} Admin Room", services().globals.server_name())) + .expect("Room name is valid"); + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomName, + content: to_raw_value(&RoomNameEventContent::new(Some(room_name))) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomTopic, + content: to_raw_value(&RoomTopicEventContent { + topic: format!("Manage {}", services().globals.server_name()), + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 6. Room alias + let alias: Box = format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); + + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomCanonicalAlias, + content: to_raw_value(&RoomCanonicalAliasEventContent { + alias: Some(alias.clone()), + alt_aliases: Vec::new(), + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + services().rooms.set_alias(&alias, Some(&room_id))?; + + Ok(()) + } + + /// Invite the user to the conduit admin room. + /// + /// In conduit, this is equivalent to granting admin privileges. + pub(crate) async fn make_user_admin( + &self, + user_id: &UserId, + displayname: String, + ) -> Result<()> { + let admin_room_alias: Box = format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); + let room_id = services() + .rooms + .id_from_alias(&admin_room_alias)? + .expect("Admin room must exist"); + + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Use the server user to grant the new admin's power level + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is valid"); + + // Invite and join the real user + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Invite, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: Some(displayname), + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + &user_id, + &room_id, + &state_lock, + )?; + + // Set power level + let mut users = BTreeMap::new(); + users.insert(conduit_user.to_owned(), 100.into()); + users.insert(user_id.to_owned(), 100.into()); + + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomPowerLevels, + content: to_raw_value(&RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // Send welcome message + services().rooms.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMessage, + content: to_raw_value(&RoomMessageEventContent::text_html( + format!("## Thank you for trying out Conduit!\n\nConduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Website: https://conduit.rs\n> Git and Documentation: https://gitlab.com/famedly/conduit\n> Report issues: https://gitlab.com/famedly/conduit/-/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nConduit room (Ask questions and get notified on updates):\n`/join #conduit:fachschaften.org`\n\nConduit lounge (Off-topic, only Conduit users are allowed to join)\n`/join #conduit-lounge:conduit.rs`", services().globals.server_name()).to_owned(), + format!("

Thank you for trying out Conduit!

\n

Conduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.

\n

Helpful links:

\n
\n

Website: https://conduit.rs
Git and Documentation: https://gitlab.com/famedly/conduit
Report issues: https://gitlab.com/famedly/conduit/-/issues

\n
\n

For a list of available commands, send the following message in this room: @conduit:{}: --help

\n

Here are some rooms you can join (by typing the command):

\n

Conduit room (Ask questions and get notified on updates):
/join #conduit:fachschaften.org

\n

Conduit lounge (Off-topic, only Conduit users are allowed to join)
/join #conduit-lounge:conduit.rs

\n", services().globals.server_name()).to_owned(), + )) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + Ok(()) + } +} diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index eed84d59..cd48e85d 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,5 +1,6 @@ +use crate::Result; + pub trait Data { - type Iter: Iterator; /// Registers an appservice and returns the ID to the caller fn register_appservice(&self, yaml: serde_yaml::Value) -> Result; @@ -12,7 +13,7 @@ pub trait Data { fn get_registration(&self, id: &str) -> Result>; - fn iter_ids(&self) -> Result>>; + fn iter_ids(&self) -> Result>>>; fn all(&self) -> Result>; } diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index ec4ffc56..63fa3afe 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,13 +1,13 @@ mod data; pub use data::Data; -use crate::service::*; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Registers an appservice and returns the ID to the caller pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { self.db.register_appservice(yaml) diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs new file mode 100644 index 00000000..f36ab61b --- /dev/null +++ b/src/service/globals/data.rs @@ -0,0 +1,8 @@ +use ruma::signatures::Ed25519KeyPair; + +use crate::Result; + +pub trait Data { + fn load_keypair(&self) -> Result; + fn remove_keypair(&self) -> Result<()>; +} diff --git a/src/service/globals.rs b/src/service/globals/mod.rs similarity index 90% rename from src/service/globals.rs rename to src/service/globals/mod.rs index 2b47e5b1..556ca71c 100644 --- a/src/service/globals.rs +++ b/src/service/globals/mod.rs @@ -3,7 +3,7 @@ pub use data::Data; use crate::service::*; -use crate::{database::Config, server_server::FedDest, utils, Error, Result}; +use crate::{Config, utils, Error, Result}; use ruma::{ api::{ client::sync::sync_events, @@ -25,8 +25,6 @@ use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; use tracing::error; use trust_dns_resolver::TokioAsyncResolver; -use super::abstraction::Tree; - pub const COUNTER: &[u8] = b"c"; type WellKnownMap = HashMap, (FedDest, String)>; @@ -93,47 +91,18 @@ impl Default for RotationHandler { } -impl Service<_> { +impl Service { pub fn load( - globals: Arc, - server_signingkeys: Arc, + db: D, config: Config, ) -> Result { - let keypair_bytes = globals.get(b"keypair")?.map_or_else( - || { - let keypair = utils::generate_keypair(); - globals.insert(b"keypair", &keypair)?; - Ok::<_, Error>(keypair) - }, - |s| Ok(s.to_vec()), - )?; - - let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff); - - let keypair = utils::string_from_bytes( - // 1. version - parts - .next() - .expect("splitn always returns at least one element"), - ) - .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) - .and_then(|version| { - // 2. key - parts - .next() - .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) - .map(|key| (version, key)) - }) - .and_then(|(version, key)| { - ruma::signatures::Ed25519KeyPair::from_der(key, version) - .map_err(|_| Error::bad_database("Private or public keys are invalid.")) - }); + let keypair = db.load_keypair(); let keypair = match keypair { Ok(k) => k, Err(e) => { error!("Keypair invalid. Deleting..."); - globals.remove(b"keypair")?; + db.remove_keypair(); return Err(e); } }; @@ -167,7 +136,7 @@ impl Service<_> { let unstable_room_versions = vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5]; let mut s = Self { - globals, + db, config, keypair: Arc::new(keypair), dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| { @@ -181,7 +150,6 @@ impl Service<_> { tls_name_override, federation_client, default_client, - server_signingkeys, jwt_decoding_key, stable_room_versions, unstable_room_versions, diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs new file mode 100644 index 00000000..6f6359eb --- /dev/null +++ b/src/service/key_backups/data.rs @@ -0,0 +1,85 @@ +use std::collections::BTreeMap; + +use ruma::{api::client::backup::{BackupAlgorithm, RoomKeyBackup, KeyBackupData}, serde::Raw, UserId, RoomId}; +use crate::Result; + +pub trait Data { + fn create_backup( + &self, + user_id: &UserId, + backup_metadata: &Raw, + ) -> Result; + + fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; + + fn update_backup( + &self, + user_id: &UserId, + version: &str, + backup_metadata: &Raw, + ) -> Result; + + fn get_latest_backup_version(&self, user_id: &UserId) -> Result>; + + fn get_latest_backup( + &self, + user_id: &UserId, + ) -> Result)>>; + + fn get_backup( + &self, + user_id: &UserId, + version: &str, + ) -> Result>>; + + fn add_key( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + key_data: &Raw, + ) -> Result<()>; + + fn count_keys(&self, user_id: &UserId, version: &str) -> Result; + + fn get_etag(&self, user_id: &UserId, version: &str) -> Result; + + fn get_all( + &self, + user_id: &UserId, + version: &str, + ) -> Result, RoomKeyBackup>>; + + fn get_room( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + ) -> Result>>; + + fn get_session( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + ) -> Result>>; + + fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>; + + fn delete_room_keys( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + ) -> Result<()>; + + fn delete_room_key( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + ) -> Result<()>; +} diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs new file mode 100644 index 00000000..8e842d4e --- /dev/null +++ b/src/service/key_backups/mod.rs @@ -0,0 +1,378 @@ +mod data; +pub use data::Data; + +use crate::{utils, Error, Result, services}; +use ruma::{ + api::client::{ + backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, + error::ErrorKind, + }, + serde::Raw, + RoomId, UserId, +}; +use std::{collections::BTreeMap, sync::Arc}; + +pub struct Service { + db: D, +} + +impl Service { + pub fn create_backup( + &self, + user_id: &UserId, + backup_metadata: &Raw, + ) -> Result { + let version = services().globals.next_count()?.to_string(); + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(version.as_bytes()); + + self.backupid_algorithm.insert( + &key, + &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), + )?; + self.backupid_etag + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + Ok(version) + } + + pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(version.as_bytes()); + + self.backupid_algorithm.remove(&key)?; + self.backupid_etag.remove(&key)?; + + key.push(0xff); + + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } + + Ok(()) + } + + pub fn update_backup( + &self, + user_id: &UserId, + version: &str, + backup_metadata: &Raw, + ) -> Result { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(version.as_bytes()); + + if self.backupid_algorithm.get(&key)?.is_none() { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Tried to update nonexistent backup.", + )); + } + + self.backupid_algorithm + .insert(&key, backup_metadata.json().get().as_bytes())?; + self.backupid_etag + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + Ok(version.to_owned()) + } + + pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.backupid_algorithm + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|(key, _)| { + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) + }) + .transpose() + } + + pub fn get_latest_backup( + &self, + user_id: &UserId, + ) -> Result)>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.backupid_algorithm + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|(key, value)| { + let version = utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; + + Ok(( + version, + serde_json::from_slice(&value).map_err(|_| { + Error::bad_database("Algorithm in backupid_algorithm is invalid.") + })?, + )) + }) + .transpose() + } + + pub fn get_backup( + &self, + user_id: &UserId, + version: &str, + ) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(version.as_bytes()); + + self.backupid_algorithm + .get(&key)? + .map_or(Ok(None), |bytes| { + serde_json::from_slice(&bytes) + .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) + }) + } + + pub fn add_key( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + key_data: &Raw, + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(version.as_bytes()); + + if self.backupid_algorithm.get(&key)?.is_none() { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Tried to update nonexistent backup.", + )); + } + + self.backupid_etag + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + + key.push(0xff); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xff); + key.extend_from_slice(session_id.as_bytes()); + + self.backupkeyid_backup + .insert(&key, key_data.json().get().as_bytes())?; + + Ok(()) + } + + pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(version.as_bytes()); + + Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) + } + + pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(version.as_bytes()); + + Ok(utils::u64_from_bytes( + &self + .backupid_etag + .get(&key)? + .ok_or_else(|| Error::bad_database("Backup has no etag."))?, + ) + .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? + .to_string()) + } + + pub fn get_all( + &self, + user_id: &UserId, + version: &str, + ) -> Result, RoomKeyBackup>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(version.as_bytes()); + prefix.push(0xff); + + let mut rooms = BTreeMap::, RoomKeyBackup>::new(); + + for result in self + .backupkeyid_backup + .scan_prefix(prefix) + .map(|(key, value)| { + let mut parts = key.rsplit(|&b| b == 0xff); + + let session_id = + utils::string_from_bytes(parts.next().ok_or_else(|| { + Error::bad_database("backupkeyid_backup key is invalid.") + })?) + .map_err(|_| { + Error::bad_database("backupkeyid_backup session_id is invalid.") + })?; + + let room_id = RoomId::parse( + utils::string_from_bytes(parts.next().ok_or_else(|| { + Error::bad_database("backupkeyid_backup key is invalid.") + })?) + .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, + ) + .map_err(|_| { + Error::bad_database("backupkeyid_backup room_id is invalid room id.") + })?; + + let key_data = serde_json::from_slice(&value).map_err(|_| { + Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") + })?; + + Ok::<_, Error>((room_id, session_id, key_data)) + }) + { + let (room_id, session_id, key_data) = result?; + rooms + .entry(room_id) + .or_insert_with(|| RoomKeyBackup { + sessions: BTreeMap::new(), + }) + .sessions + .insert(session_id, key_data); + } + + Ok(rooms) + } + + pub fn get_room( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + ) -> Result>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(version.as_bytes()); + prefix.push(0xff); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xff); + + Ok(self + .backupkeyid_backup + .scan_prefix(prefix) + .map(|(key, value)| { + let mut parts = key.rsplit(|&b| b == 0xff); + + let session_id = + utils::string_from_bytes(parts.next().ok_or_else(|| { + Error::bad_database("backupkeyid_backup key is invalid.") + })?) + .map_err(|_| { + Error::bad_database("backupkeyid_backup session_id is invalid.") + })?; + + let key_data = serde_json::from_slice(&value).map_err(|_| { + Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") + })?; + + Ok::<_, Error>((session_id, key_data)) + }) + .filter_map(|r| r.ok()) + .collect()) + } + + pub fn get_session( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + ) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(version.as_bytes()); + key.push(0xff); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xff); + key.extend_from_slice(session_id.as_bytes()); + + self.backupkeyid_backup + .get(&key)? + .map(|value| { + serde_json::from_slice(&value).map_err(|_| { + Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") + }) + }) + .transpose() + } + + pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(version.as_bytes()); + key.push(0xff); + + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } + + Ok(()) + } + + pub fn delete_room_keys( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(version.as_bytes()); + key.push(0xff); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xff); + + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } + + Ok(()) + } + + pub fn delete_room_key( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(version.as_bytes()); + key.push(0xff); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xff); + key.extend_from_slice(session_id.as_bytes()); + + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } + + Ok(()) + } +} diff --git a/src/service/media.rs b/src/service/media.rs deleted file mode 100644 index 1bdf6d47..00000000 --- a/src/service/media.rs +++ /dev/null @@ -1,357 +0,0 @@ -use image::{imageops::FilterType, GenericImageView}; - -use super::abstraction::Tree; -use crate::{utils, Error, Result}; -use std::{mem, sync::Arc}; -use tokio::{ - fs::File, - io::{AsyncReadExt, AsyncWriteExt}, -}; - -pub struct FileMeta { - pub content_disposition: Option, - pub content_type: Option, - pub file: Vec, -} - -pub struct Media { - pub(super) mediaid_file: Arc, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType -} - -impl Media { - /// Uploads a file. - pub async fn create( - &self, - mxc: String, - globals: &Globals, - content_disposition: &Option<&str>, - content_type: &Option<&str>, - file: &[u8], - ) -> Result<()> { - let mut key = mxc.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail - key.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail - key.push(0xff); - key.extend_from_slice( - content_disposition - .as_ref() - .map(|f| f.as_bytes()) - .unwrap_or_default(), - ); - key.push(0xff); - key.extend_from_slice( - content_type - .as_ref() - .map(|c| c.as_bytes()) - .unwrap_or_default(), - ); - - let path = globals.get_media_file(&key); - let mut f = File::create(path).await?; - f.write_all(file).await?; - - self.mediaid_file.insert(&key, &[])?; - Ok(()) - } - - /// Uploads or replaces a file thumbnail. - #[allow(clippy::too_many_arguments)] - pub async fn upload_thumbnail( - &self, - mxc: String, - globals: &Globals, - content_disposition: &Option, - content_type: &Option, - width: u32, - height: u32, - file: &[u8], - ) -> Result<()> { - let mut key = mxc.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&width.to_be_bytes()); - key.extend_from_slice(&height.to_be_bytes()); - key.push(0xff); - key.extend_from_slice( - content_disposition - .as_ref() - .map(|f| f.as_bytes()) - .unwrap_or_default(), - ); - key.push(0xff); - key.extend_from_slice( - content_type - .as_ref() - .map(|c| c.as_bytes()) - .unwrap_or_default(), - ); - - let path = globals.get_media_file(&key); - let mut f = File::create(path).await?; - f.write_all(file).await?; - - self.mediaid_file.insert(&key, &[])?; - - Ok(()) - } - - /// Downloads a file. - pub async fn get(&self, globals: &Globals, mxc: &str) -> Result> { - let mut prefix = mxc.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail - prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail - prefix.push(0xff); - - let first = self.mediaid_file.scan_prefix(prefix).next(); - if let Some((key, _)) = first { - let path = globals.get_media_file(&key); - let mut file = Vec::new(); - File::open(path).await?.read_to_end(&mut file).await?; - let mut parts = key.rsplit(|&b| b == 0xff); - - let content_type = parts - .next() - .map(|bytes| { - utils::string_from_bytes(bytes).map_err(|_| { - Error::bad_database("Content type in mediaid_file is invalid unicode.") - }) - }) - .transpose()?; - - let content_disposition_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; - - let content_disposition = if content_disposition_bytes.is_empty() { - None - } else { - Some( - utils::string_from_bytes(content_disposition_bytes).map_err(|_| { - Error::bad_database( - "Content Disposition in mediaid_file is invalid unicode.", - ) - })?, - ) - }; - - Ok(Some(FileMeta { - content_disposition, - content_type, - file, - })) - } else { - Ok(None) - } - } - - /// Returns width, height of the thumbnail and whether it should be cropped. Returns None when - /// the server should send the original file. - pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> { - match (width, height) { - (0..=32, 0..=32) => Some((32, 32, true)), - (0..=96, 0..=96) => Some((96, 96, true)), - (0..=320, 0..=240) => Some((320, 240, false)), - (0..=640, 0..=480) => Some((640, 480, false)), - (0..=800, 0..=600) => Some((800, 600, false)), - _ => None, - } - } - - /// Downloads a file's thumbnail. - /// - /// Here's an example on how it works: - /// - /// - Client requests an image with width=567, height=567 - /// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails - /// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96) - /// - Server creates the thumbnail and sends it to the user - /// - /// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards. - pub async fn get_thumbnail( - &self, - mxc: &str, - globals: &Globals, - width: u32, - height: u32, - ) -> Result> { - let (width, height, crop) = self - .thumbnail_properties(width, height) - .unwrap_or((0, 0, false)); // 0, 0 because that's the original file - - let mut main_prefix = mxc.as_bytes().to_vec(); - main_prefix.push(0xff); - - let mut thumbnail_prefix = main_prefix.clone(); - thumbnail_prefix.extend_from_slice(&width.to_be_bytes()); - thumbnail_prefix.extend_from_slice(&height.to_be_bytes()); - thumbnail_prefix.push(0xff); - - let mut original_prefix = main_prefix; - original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail - original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail - original_prefix.push(0xff); - - let first_thumbnailprefix = self.mediaid_file.scan_prefix(thumbnail_prefix).next(); - let first_originalprefix = self.mediaid_file.scan_prefix(original_prefix).next(); - if let Some((key, _)) = first_thumbnailprefix { - // Using saved thumbnail - let path = globals.get_media_file(&key); - let mut file = Vec::new(); - File::open(path).await?.read_to_end(&mut file).await?; - let mut parts = key.rsplit(|&b| b == 0xff); - - let content_type = parts - .next() - .map(|bytes| { - utils::string_from_bytes(bytes).map_err(|_| { - Error::bad_database("Content type in mediaid_file is invalid unicode.") - }) - }) - .transpose()?; - - let content_disposition_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; - - let content_disposition = if content_disposition_bytes.is_empty() { - None - } else { - Some( - utils::string_from_bytes(content_disposition_bytes).map_err(|_| { - Error::bad_database("Content Disposition in db is invalid.") - })?, - ) - }; - - Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.to_vec(), - })) - } else if let Some((key, _)) = first_originalprefix { - // Generate a thumbnail - let path = globals.get_media_file(&key); - let mut file = Vec::new(); - File::open(path).await?.read_to_end(&mut file).await?; - - let mut parts = key.rsplit(|&b| b == 0xff); - - let content_type = parts - .next() - .map(|bytes| { - utils::string_from_bytes(bytes).map_err(|_| { - Error::bad_database("Content type in mediaid_file is invalid unicode.") - }) - }) - .transpose()?; - - let content_disposition_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; - - let content_disposition = if content_disposition_bytes.is_empty() { - None - } else { - Some( - utils::string_from_bytes(content_disposition_bytes).map_err(|_| { - Error::bad_database( - "Content Disposition in mediaid_file is invalid unicode.", - ) - })?, - ) - }; - - if let Ok(image) = image::load_from_memory(&file) { - let original_width = image.width(); - let original_height = image.height(); - if width > original_width || height > original_height { - return Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.to_vec(), - })); - } - - let thumbnail = if crop { - image.resize_to_fill(width, height, FilterType::CatmullRom) - } else { - let (exact_width, exact_height) = { - // Copied from image::dynimage::resize_dimensions - let ratio = u64::from(original_width) * u64::from(height); - let nratio = u64::from(width) * u64::from(original_height); - - let use_width = nratio <= ratio; - let intermediate = if use_width { - u64::from(original_height) * u64::from(width) - / u64::from(original_width) - } else { - u64::from(original_width) * u64::from(height) - / u64::from(original_height) - }; - if use_width { - if intermediate <= u64::from(::std::u32::MAX) { - (width, intermediate as u32) - } else { - ( - (u64::from(width) * u64::from(::std::u32::MAX) / intermediate) - as u32, - ::std::u32::MAX, - ) - } - } else if intermediate <= u64::from(::std::u32::MAX) { - (intermediate as u32, height) - } else { - ( - ::std::u32::MAX, - (u64::from(height) * u64::from(::std::u32::MAX) / intermediate) - as u32, - ) - } - }; - - image.thumbnail_exact(exact_width, exact_height) - }; - - let mut thumbnail_bytes = Vec::new(); - thumbnail.write_to(&mut thumbnail_bytes, image::ImageOutputFormat::Png)?; - - // Save thumbnail in database so we don't have to generate it again next time - let mut thumbnail_key = key.to_vec(); - let width_index = thumbnail_key - .iter() - .position(|&b| b == 0xff) - .ok_or_else(|| Error::bad_database("Media in db is invalid."))? - + 1; - let mut widthheight = width.to_be_bytes().to_vec(); - widthheight.extend_from_slice(&height.to_be_bytes()); - - thumbnail_key.splice( - width_index..width_index + 2 * mem::size_of::(), - widthheight, - ); - - let path = globals.get_media_file(&thumbnail_key); - let mut f = File::create(path).await?; - f.write_all(&thumbnail_bytes).await?; - - self.mediaid_file.insert(&thumbnail_key, &[])?; - - Ok(Some(FileMeta { - content_disposition, - content_type, - file: thumbnail_bytes.to_vec(), - })) - } else { - // Couldn't parse file to generate thumbnail, send original - Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.to_vec(), - })) - } - } else { - Ok(None) - } - } -} diff --git a/src/service/media/data.rs b/src/service/media/data.rs new file mode 100644 index 00000000..94975de7 --- /dev/null +++ b/src/service/media/data.rs @@ -0,0 +1,8 @@ +use crate::Result; + +pub trait Data { + fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: &Option<&str>, content_type: &Option<&str>) -> Result>; + + /// Returns content_disposition, content_type and the metadata key. + fn search_file_metadata(&self, mxc: String, width: u32, height: u32) -> Result<(Option, Option, Vec)>; +} diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs new file mode 100644 index 00000000..a5aca036 --- /dev/null +++ b/src/service/media/mod.rs @@ -0,0 +1,206 @@ +mod data; +pub use data::Data; + +use image::{imageops::FilterType, GenericImageView}; +use crate::{utils, Error, Result, services}; +use std::{mem, sync::Arc}; +use tokio::{ + fs::File, + io::{AsyncReadExt, AsyncWriteExt}, +}; + +pub struct FileMeta { + pub content_disposition: Option, + pub content_type: Option, + pub file: Vec, +} + +pub struct Service { + db: D, +} + +impl Service { + /// Uploads a file. + pub async fn create( + &self, + mxc: String, + content_disposition: &Option<&str>, + content_type: &Option<&str>, + file: &[u8], + ) -> Result<()> { + // Width, Height = 0 if it's not a thumbnail + let key = self.db.create_file_metadata(mxc, 0, 0, content_disposition, content_type); + + let path = services().globals.get_media_file(&key); + let mut f = File::create(path).await?; + f.write_all(file).await?; + Ok(()) + } + + /// Uploads or replaces a file thumbnail. + #[allow(clippy::too_many_arguments)] + pub async fn upload_thumbnail( + &self, + mxc: String, + content_disposition: &Option, + content_type: &Option, + width: u32, + height: u32, + file: &[u8], + ) -> Result<()> { + let key = self.db.create_file_metadata(mxc, width, height, content_disposition, content_type); + + let path = services().globals.get_media_file(&key); + let mut f = File::create(path).await?; + f.write_all(file).await?; + + Ok(()) + } + + /// Downloads a file. + pub async fn get(&self, mxc: String) -> Result> { + if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) { + let path = services().globals.get_media_file(&key); + let mut file = Vec::new(); + File::open(path).await?.read_to_end(&mut file).await?; + + + Ok(Some(FileMeta { + content_disposition, + content_type, + file, + })) + } else { + Ok(None) + } + } + + /// Returns width, height of the thumbnail and whether it should be cropped. Returns None when + /// the server should send the original file. + pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> { + match (width, height) { + (0..=32, 0..=32) => Some((32, 32, true)), + (0..=96, 0..=96) => Some((96, 96, true)), + (0..=320, 0..=240) => Some((320, 240, false)), + (0..=640, 0..=480) => Some((640, 480, false)), + (0..=800, 0..=600) => Some((800, 600, false)), + _ => None, + } + } + + /// Downloads a file's thumbnail. + /// + /// Here's an example on how it works: + /// + /// - Client requests an image with width=567, height=567 + /// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails + /// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96) + /// - Server creates the thumbnail and sends it to the user + /// + /// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards. + pub async fn get_thumbnail( + &self, + mxc: String, + width: u32, + height: u32, + ) -> Result> { + let (width, height, crop) = self + .thumbnail_properties(width, height) + .unwrap_or((0, 0, false)); // 0, 0 because that's the original file + + if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, width, height) { + // Using saved thumbnail + let path = services().globals.get_media_file(&key); + let mut file = Vec::new(); + File::open(path).await?.read_to_end(&mut file).await?; + + Ok(Some(FileMeta { + content_disposition, + content_type, + file: file.to_vec(), + })) + } else if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) { + // Generate a thumbnail + let path = services().globals.get_media_file(&key); + let mut file = Vec::new(); + File::open(path).await?.read_to_end(&mut file).await?; + + if let Ok(image) = image::load_from_memory(&file) { + let original_width = image.width(); + let original_height = image.height(); + if width > original_width || height > original_height { + return Ok(Some(FileMeta { + content_disposition, + content_type, + file: file.to_vec(), + })); + } + + let thumbnail = if crop { + image.resize_to_fill(width, height, FilterType::CatmullRom) + } else { + let (exact_width, exact_height) = { + // Copied from image::dynimage::resize_dimensions + let ratio = u64::from(original_width) * u64::from(height); + let nratio = u64::from(width) * u64::from(original_height); + + let use_width = nratio <= ratio; + let intermediate = if use_width { + u64::from(original_height) * u64::from(width) + / u64::from(original_width) + } else { + u64::from(original_width) * u64::from(height) + / u64::from(original_height) + }; + if use_width { + if intermediate <= u64::from(::std::u32::MAX) { + (width, intermediate as u32) + } else { + ( + (u64::from(width) * u64::from(::std::u32::MAX) / intermediate) + as u32, + ::std::u32::MAX, + ) + } + } else if intermediate <= u64::from(::std::u32::MAX) { + (intermediate as u32, height) + } else { + ( + ::std::u32::MAX, + (u64::from(height) * u64::from(::std::u32::MAX) / intermediate) + as u32, + ) + } + }; + + image.thumbnail_exact(exact_width, exact_height) + }; + + let mut thumbnail_bytes = Vec::new(); + thumbnail.write_to(&mut thumbnail_bytes, image::ImageOutputFormat::Png)?; + + // Save thumbnail in database so we don't have to generate it again next time + let thumbnail_key = self.db.create_file_metadata(mxc, width, height, content_disposition, content_type)?; + + let path = services().globals.get_media_file(&thumbnail_key); + let mut f = File::create(path).await?; + f.write_all(&thumbnail_bytes).await?; + + Ok(Some(FileMeta { + content_disposition, + content_type, + file: thumbnail_bytes.to_vec(), + })) + } else { + // Couldn't parse file to generate thumbnail, send original + Ok(Some(FileMeta { + content_disposition, + content_type, + file: file.to_vec(), + })) + } + } else { + Ok(None) + } + } +} diff --git a/src/service/mod.rs b/src/service/mod.rs index 80239cbf..4364c72e 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,28 +1,29 @@ -pub mod pdu; -pub mod appservice; -pub mod pusher; -pub mod rooms; -pub mod transaction_ids; -pub mod uiaa; -pub mod users; pub mod account_data; pub mod admin; +pub mod appservice; pub mod globals; pub mod key_backups; pub mod media; +pub mod pdu; +pub mod pusher; +pub mod rooms; pub mod sending; +pub mod transaction_ids; +pub mod uiaa; +pub mod users; -pub struct Services { +pub struct Services +{ pub appservice: appservice::Service, pub pusher: pusher::Service, pub rooms: rooms::Service, pub transaction_ids: transaction_ids::Service, pub uiaa: uiaa::Service, pub users: users::Service, - //pub account_data: account_data::Service, - //pub admin: admin::Service, + pub account_data: account_data::Service, + pub admin: admin::Service, pub globals: globals::Service, - //pub key_backups: key_backups::Service, - //pub media: media::Service, - //pub sending: sending::Service, + pub key_backups: key_backups::Service, + pub media: media::Service, + pub sending: sending::Service, } diff --git a/src/service/pdu.rs b/src/service/pdu.rs index 47e21a60..2ed79f2c 100644 --- a/src/service/pdu.rs +++ b/src/service/pdu.rs @@ -1,4 +1,4 @@ -use crate::{Database, Error, services}; +use crate::{Error, services}; use ruma::{ events::{ room::member::RoomMemberEventContent, AnyEphemeralRoomEvent, AnyRoomEvent, AnyStateEvent, @@ -357,7 +357,7 @@ pub(crate) fn gen_event_id_canonical_json( Ok((event_id, value)) } -/// Build the start of a PDU in order to add it to the `Database`. +/// Build the start of a PDU in order to add it to the Database. #[derive(Debug, Deserialize)] pub struct PduBuilder { #[serde(rename = "type")] diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index ef2b8193..3951da79 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -1,4 +1,5 @@ use ruma::{UserId, api::client::push::{set_pusher, get_pushers}}; +use crate::Result; pub trait Data { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()>; @@ -10,5 +11,5 @@ pub trait Data { fn get_pusher_senderkeys<'a>( &'a self, sender: &UserId, - ) -> impl Iterator> + 'a; + ) -> Box>>; } diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 87e91a14..66a8ae36 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,7 +1,7 @@ mod data; pub use data::Data; -use crate::{services, Error, PduEvent}; +use crate::{services, Error, PduEvent, Result}; use bytes::BytesMut; use ruma::{ api::{ @@ -27,7 +27,7 @@ pub struct Service { db: D, } -impl Service<_> { +impl Service { pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { self.db.set_pusher(sender, pusher) } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index 655f32aa..c5d45e36 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,24 +1,29 @@ use ruma::{RoomId, RoomAliasId}; +use crate::Result; pub trait Data { /// Creates or updates the alias to the given room id. fn set_alias( + &self, alias: &RoomAliasId, room_id: &RoomId ) -> Result<()>; /// Forgets about an alias. Returns an error if the alias did not exist. fn remove_alias( + &self, alias: &RoomAliasId, ) -> Result<()>; /// Looks up the roomid for the given alias. fn resolve_local_alias( + &self, alias: &RoomAliasId, - ) -> Result<()>; + ) -> Result>>; /// Returns all local aliases that point to the given room fn local_aliases_for_room( - alias: &RoomAliasId, - ) -> Result<()>; + &self, + room_id: &RoomId, + ) -> Result>>; } diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index f46609aa..abe299d4 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,12 +1,14 @@ mod data; pub use data::Data; + use ruma::{RoomAliasId, RoomId}; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { #[tracing::instrument(skip(self))] pub fn set_alias( &self, @@ -26,7 +28,7 @@ impl Service<_> { #[tracing::instrument(skip(self))] pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result>> { - self.db.resolve_local_alias(alias: &RoomAliasId) + self.db.resolve_local_alias(alias) } #[tracing::instrument(skip(self))] diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 88c86fad..5177d6d6 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -1,6 +1,7 @@ use std::collections::HashSet; +use crate::Result; pub trait Data { - fn get_cached_eventid_authchain<'a>() -> Result>; - fn cache_eventid_authchain<'a>(shorteventid: u64, auth_chain: &HashSet) -> Result>; + fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result>; + fn cache_eventid_authchain(&self, shorteventid: u64, auth_chain: &HashSet) -> Result<()>; } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index e17c10a1..113d2e81 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -3,13 +3,13 @@ use std::{sync::Arc, collections::HashSet}; pub use data::Data; -use crate::service::*; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { #[tracing::instrument(skip(self))] pub fn get_cached_eventid_authchain<'a>( &'a self, diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs index e28cdd12..13767217 100644 --- a/src/service/rooms/directory/data.rs +++ b/src/service/rooms/directory/data.rs @@ -1,15 +1,16 @@ use ruma::RoomId; +use crate::Result; pub trait Data { /// Adds the room to the public room directory - fn set_public(room_id: &RoomId) -> Result<()>; + fn set_public(&self, room_id: &RoomId) -> Result<()>; /// Removes the room from the public room directory. - fn set_not_public(room_id: &RoomId) -> Result<()>; + fn set_not_public(&self, room_id: &RoomId) -> Result<()>; /// Returns true if the room is in the public room directory. - fn is_public_room(room_id: &RoomId) -> Result; + fn is_public_room(&self, room_id: &RoomId) -> Result; /// Returns the unsorted public room directory - fn public_rooms() -> impl Iterator>> + '_; + fn public_rooms(&self) -> Box>>>; } diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index cb9cda86..68535057 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -2,13 +2,13 @@ mod data; pub use data::Data; use ruma::RoomId; -use crate::service::*; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { #[tracing::instrument(skip(self))] pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) diff --git a/src/service/rooms/edus/mod.rs b/src/service/rooms/edus/mod.rs index 5566fb2c..a5ce37f1 100644 --- a/src/service/rooms/edus/mod.rs +++ b/src/service/rooms/edus/mod.rs @@ -2,7 +2,9 @@ pub mod presence; pub mod read_receipt; pub mod typing; -pub struct Service { +pub trait Data: presence::Data + read_receipt::Data + typing::Data {} + +pub struct Service { presence: presence::Service, read_receipt: read_receipt::Service, typing: typing::Service, diff --git a/src/service/rooms/edus/presence/data.rs b/src/service/rooms/edus/presence/data.rs index 8e3c672f..ca0e2410 100644 --- a/src/service/rooms/edus/presence/data.rs +++ b/src/service/rooms/edus/presence/data.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use ruma::{UserId, RoomId, events::presence::PresenceEvent}; +use crate::Result; pub trait Data { /// Adds a presence event which will be saved until a new event replaces it. diff --git a/src/service/rooms/edus/presence/mod.rs b/src/service/rooms/edus/presence/mod.rs index 5a988d4f..646cf549 100644 --- a/src/service/rooms/edus/presence/mod.rs +++ b/src/service/rooms/edus/presence/mod.rs @@ -4,13 +4,13 @@ use std::collections::HashMap; pub use data::Data; use ruma::{RoomId, UserId, events::presence::PresenceEvent}; -use crate::service::*; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Adds a presence event which will be saved until a new event replaces it. /// /// Note: This method takes a RoomId because presence updates are always bound to rooms to diff --git a/src/service/rooms/edus/read_receipt/data.rs b/src/service/rooms/edus/read_receipt/data.rs index 32b091f2..e8ed9656 100644 --- a/src/service/rooms/edus/read_receipt/data.rs +++ b/src/service/rooms/edus/read_receipt/data.rs @@ -1,4 +1,5 @@ use ruma::{RoomId, events::receipt::ReceiptEvent, UserId, serde::Raw}; +use crate::Result; pub trait Data { /// Replaces the previous read receipt. @@ -14,13 +15,13 @@ pub trait Data { &self, room_id: &RoomId, since: u64, - ) -> impl Iterator< + ) -> Box, u64, Raw, )>, - >; + >>; /// Sets a private read marker at `count`. fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; diff --git a/src/service/rooms/edus/read_receipt/mod.rs b/src/service/rooms/edus/read_receipt/mod.rs index 744fece1..3f0b1476 100644 --- a/src/service/rooms/edus/read_receipt/mod.rs +++ b/src/service/rooms/edus/read_receipt/mod.rs @@ -1,12 +1,14 @@ mod data; pub use data::Data; + use ruma::{RoomId, UserId, events::receipt::ReceiptEvent, serde::Raw}; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Replaces the previous read receipt. pub fn readreceipt_update( &self, diff --git a/src/service/rooms/edus/typing/data.rs b/src/service/rooms/edus/typing/data.rs index 0c773135..ec0be466 100644 --- a/src/service/rooms/edus/typing/data.rs +++ b/src/service/rooms/edus/typing/data.rs @@ -1,5 +1,5 @@ use std::collections::HashSet; - +use crate::Result; use ruma::{UserId, RoomId}; pub trait Data { @@ -14,5 +14,5 @@ pub trait Data { fn last_typing_update(&self, room_id: &RoomId) -> Result; /// Returns all user ids currently typing. - fn typings_all(&self, room_id: &RoomId) -> Result>; + fn typings_all(&self, room_id: &RoomId) -> Result>>; } diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs index 68b9fd83..00cfdecb 100644 --- a/src/service/rooms/edus/typing/mod.rs +++ b/src/service/rooms/edus/typing/mod.rs @@ -1,14 +1,14 @@ mod data; pub use data::Data; -use ruma::{UserId, RoomId}; +use ruma::{UserId, RoomId, events::SyncEphemeralRoomEvent}; -use crate::service::*; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is /// called. pub fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 71529570..c9b041c2 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -250,7 +250,7 @@ impl Service { // We go through all the signatures we see on the value and fetch the corresponding signing // keys - self.fetch_required_signing_keys(&value, pub_key_map, db) + self.fetch_required_signing_keys(&value, pub_key_map) .await?; // 2. Check signatures, otherwise drop @@ -1153,6 +1153,11 @@ impl Service { let mut eventid_info = HashMap::new(); let mut todo_outlier_stack: Vec> = initial_set; + let first_pdu_in_room = services() + .rooms + .first_pdu_in_room(room_id)? + .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + let mut amount = 0; while let Some(prev_event_id) = todo_outlier_stack.pop() { diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs index 52a683d3..5fefd3f8 100644 --- a/src/service/rooms/lazy_loading/data.rs +++ b/src/service/rooms/lazy_loading/data.rs @@ -1,4 +1,5 @@ use ruma::{RoomId, DeviceId, UserId}; +use crate::Result; pub trait Data { fn lazy_load_was_sent_before( diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index bdc083a0..283d45af 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -4,13 +4,13 @@ use std::collections::HashSet; pub use data::Data; use ruma::{DeviceId, UserId, RoomId}; -use crate::service::*; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { #[tracing::instrument(skip(self))] pub fn lazy_load_was_sent_before( &self, diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 2d718b2d..9b1ce079 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -1,4 +1,5 @@ use ruma::RoomId; +use crate::Result; pub trait Data { fn exists(&self, room_id: &RoomId) -> Result; diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 8417e28e..1bdb78d6 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -2,13 +2,13 @@ mod data; pub use data::Data; use ruma::RoomId; -use crate::service::*; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Checks if a room exists. #[tracing::instrument(skip(self))] pub fn exists(&self, room_id: &RoomId) -> Result { diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index 47250340..4da42236 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -16,7 +16,9 @@ pub mod state_compressor; pub mod timeline; pub mod user; -pub struct Service { +pub trait Data: alias::Data + auth_chain::Data + directory::Data + edus::Data + lazy_loading::Data + metadata::Data + outlier::Data + pdu_metadata::Data + search::Data + short::Data + state::Data + state_accessor::Data + state_cache::Data + state_compressor::Data + timeline::Data + user::Data {} + +pub struct Service { pub alias: alias::Service, pub auth_chain: auth_chain::Service, pub directory: directory::Service, diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs index d579515e..17d0f7b4 100644 --- a/src/service/rooms/outlier/data.rs +++ b/src/service/rooms/outlier/data.rs @@ -1,6 +1,6 @@ -use ruma::{EventId, signatures::CanonicalJsonObject}; +use ruma::{signatures::CanonicalJsonObject, EventId}; -use crate::PduEvent; +use crate::{PduEvent, Result}; pub trait Data { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result>; diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index ee8b940f..a495db8f 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -2,13 +2,13 @@ mod data; pub use data::Data; use ruma::{EventId, signatures::CanonicalJsonObject}; -use crate::{service::*, PduEvent}; +use crate::{Result, PduEvent}; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Returns the pdu from the outlier tree. pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.db.get_outlier_pdu_json(event_id) diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 531823fe..fb839023 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use ruma::{EventId, RoomId}; +use crate::Result; pub trait Data { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()>; diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 3442b830..c57c1a28 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -4,13 +4,13 @@ use std::sync::Arc; pub use data::Data; use ruma::{RoomId, EventId}; -use crate::service::*; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { #[tracing::instrument(skip(self, room_id, event_ids))] pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { self.db.mark_as_referenced(room_id, event_ids) diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 16287eba..c0fd2a37 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,11 +1,12 @@ use ruma::RoomId; +use crate::Result; pub trait Data { - fn index_pdu<'a>(&self, room_id: &RoomId, pdu_id: u64, message_body: String) -> Result<()>; + fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: u64, message_body: String) -> Result<()>; fn search_pdus<'a>( &'a self, room_id: &RoomId, search_string: &str, - ) -> Result> + 'a, Vec)>>; + ) -> Result>>, Vec)>>; } diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 9087deff..b7023f32 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,12 +1,14 @@ mod data; pub use data::Data; + +use crate::Result; use ruma::RoomId; pub struct Service { db: D, } -impl Service<_> { +impl Service { #[tracing::instrument(skip(self))] pub fn search_pdus<'a>( &'a self, diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs new file mode 100644 index 00000000..3b1c3117 --- /dev/null +++ b/src/service/rooms/short/data.rs @@ -0,0 +1,2 @@ +pub trait Data { +} diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index afde14e2..1eb891e6 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -2,19 +2,18 @@ mod data; use std::sync::Arc; pub use data::Data; -use ruma::{EventId, events::StateEventType}; +use ruma::{EventId, events::StateEventType, RoomId}; -use crate::{service::*, Error, utils}; +use crate::{Result, Error, utils, services}; pub struct Service { db: D, } -impl Service<_> { +impl Service { pub fn get_or_create_shorteventid( &self, event_id: &EventId, - globals: &super::globals::Globals, ) -> Result { if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { return Ok(*short); @@ -24,7 +23,7 @@ impl Service<_> { Some(shorteventid) => utils::u64_from_bytes(&shorteventid) .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, None => { - let shorteventid = globals.next_count()?; + let shorteventid = services().globals.next_count()?; self.eventid_shorteventid .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; self.shorteventid_eventid @@ -82,7 +81,6 @@ impl Service<_> { &self, event_type: &StateEventType, state_key: &str, - globals: &super::globals::Globals, ) -> Result { if let Some(short) = self .statekeyshort_cache @@ -101,7 +99,7 @@ impl Service<_> { Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, None => { - let shortstatekey = globals.next_count()?; + let shortstatekey = services().globals.next_count()?; self.statekey_shortstatekey .insert(&statekey, &shortstatekey.to_be_bytes())?; self.shortstatekey_statekey @@ -190,7 +188,7 @@ impl Service<_> { /// Returns (shortstatehash, already_existed) fn get_or_create_shortstatehash( &self, - state_hash: &StateHashId, + state_hash: &[u8], ) -> Result<(u64, bool)> { Ok(match self.statehash_shortstatehash.get(state_hash)? { Some(shortstatehash) => ( @@ -199,7 +197,7 @@ impl Service<_> { true, ), None => { - let shortstatehash = globals.next_count()?; + let shortstatehash = services().globals.next_count()?; self.statehash_shortstatehash .insert(state_hash, &shortstatehash.to_be_bytes())?; (shortstatehash, false) @@ -220,13 +218,12 @@ impl Service<_> { pub fn get_or_create_shortroomid( &self, room_id: &RoomId, - globals: &super::globals::Globals, ) -> Result { Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { Some(short) => utils::u64_from_bytes(&short) .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, None => { - let short = globals.next_count()?; + let short = services().globals.next_count()?; self.roomid_shortroomid .insert(room_id.as_bytes(), &short.to_be_bytes())?; short diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index ac8fac21..fd0de282 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,30 +1,28 @@ use std::sync::Arc; use std::{sync::MutexGuard, collections::HashSet}; use std::fmt::Debug; - +use crate::Result; use ruma::{EventId, RoomId}; pub trait Data { /// Returns the last state hash key added to the db for the given room. - fn get_room_shortstatehash(room_id: &RoomId); + fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result>; /// Update the current state of the room. - fn set_room_state(room_id: &RoomId, new_shortstatehash: u64, - _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex - ); + fn set_room_state(&self, room_id: &RoomId, new_shortstatehash: u64, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()>; /// Associates a state with an event. - fn set_event_state(shorteventid: u64, shortstatehash: u64) -> Result<()>; + fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()>; /// Returns all events we would send as the prev_events of the next event. - fn get_forward_extremities(room_id: &RoomId) -> Result>>; + fn get_forward_extremities(&self, room_id: &RoomId) -> Result>>; /// Replace the forward extremities of the room. - fn set_forward_extremities( + fn set_forward_extremities<'a>(&self, room_id: &RoomId, - event_ids: impl IntoIterator + Debug, - _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex + event_ids: impl IntoIterator + Debug, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()>; } - -pub struct StateLock; diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 6c33d521..e6b5ce20 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -6,13 +6,15 @@ use ruma::{RoomId, events::{room::{member::MembershipState, create::RoomCreateEv use serde::Deserialize; use tracing::warn; -use crate::{service::*, SERVICE, PduEvent, Error, utils::calculate_hash}; +use crate::{Result, services, PduEvent, Error, utils::calculate_hash}; + +use super::state_compressor::CompressedStateEvent; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Set the room to the given statehash and update caches. pub fn force_state( &self, @@ -23,11 +25,11 @@ impl Service<_> { ) -> Result<()> { for event_id in statediffnew.into_iter().filter_map(|new| { - SERVICE.rooms.state_compressor.parse_compressed_state_event(new) + services().rooms.state_compressor.parse_compressed_state_event(new) .ok() .map(|(_, id)| id) }) { - let pdu = match SERVICE.rooms.timeline.get_pdu_json(&event_id)? { + let pdu = match services().rooms.timeline.get_pdu_json(&event_id)? { Some(pdu) => pdu, None => continue, }; @@ -63,10 +65,10 @@ impl Service<_> { Err(_) => continue, }; - SERVICE.room.state_cache.update_membership(room_id, &user_id, membership, &pdu.sender, None, false)?; + services().room.state_cache.update_membership(room_id, &user_id, membership, &pdu.sender, None, false)?; } - SERVICE.room.state_cache.update_joined_count(room_id)?; + services().room.state_cache.update_joined_count(room_id)?; self.db.set_room_state(room_id, shortstatehash); @@ -84,7 +86,7 @@ impl Service<_> { room_id: &RoomId, state_ids_compressed: HashSet, ) -> Result<()> { - let shorteventid = SERVICE.short.get_or_create_shorteventid(event_id)?; + let shorteventid = services().short.get_or_create_shorteventid(event_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; @@ -96,11 +98,11 @@ impl Service<_> { ); let (shortstatehash, already_existed) = - SERVICE.short.get_or_create_shortstatehash(&state_hash)?; + services().short.get_or_create_shortstatehash(&state_hash)?; if !already_existed { let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| SERVICE.room.state_compressor.load_shortstatehash_info(p))?; + .map_or_else(|| Ok(Vec::new()), |p| services().room.state_compressor.load_shortstatehash_info(p))?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { @@ -119,7 +121,7 @@ impl Service<_> { } else { (state_ids_compressed, HashSet::new()) }; - SERVICE.room.state_compressor.save_state_from_diff( + services().room.state_compressor.save_state_from_diff( shortstatehash, statediffnew, statediffremoved, @@ -176,7 +178,7 @@ impl Service<_> { } // TODO: statehash with deterministic inputs - let shortstatehash = SERVICE.globals.next_count()?; + let shortstatehash = services().globals.next_count()?; let mut statediffnew = HashSet::new(); statediffnew.insert(new); @@ -273,4 +275,8 @@ impl Service<_> { .ok_or_else(|| Error::BadDatabase("Invalid room version"))?; Ok(room_version) } + + pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { + self.db.get_room_shortstatehash(room_id) + } } diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index bf2972f9..48031e49 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,9 +1,11 @@ -use std::{sync::Arc, collections::HashMap}; +use std::{sync::Arc, collections::{HashMap, BTreeMap}}; +use async_trait::async_trait; use ruma::{EventId, events::StateEventType, RoomId}; -use crate::PduEvent; +use crate::{Result, PduEvent}; +#[async_trait] pub trait Data { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 92e5c8e1..5d6886d9 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -4,13 +4,13 @@ use std::{sync::Arc, collections::{HashMap, BTreeMap}}; pub use data::Data; use ruma::{events::StateEventType, RoomId, EventId}; -use crate::{service::*, PduEvent}; +use crate::{Result, PduEvent}; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self))] diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index f6519196..b45b2ea0 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,5 +1,9 @@ -use ruma::{UserId, RoomId}; +use ruma::{UserId, RoomId, serde::Raw, events::AnyStrippedStateEvent}; +use crate::Result; pub trait Data { - fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn mark_as_invited(&self, user_id: &UserId, room_id: &RoomId, last_state: Option>>) -> Result<()>; + fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index d29501a6..c3b4eb91 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -5,13 +5,13 @@ pub use data::Data; use regex::Regex; use ruma::{RoomId, UserId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType, tag::TagEvent, RoomAccountDataEventType, GlobalAccountDataEventType, direct::DirectEvent, ignored_user_list::IgnoredUserListEvent, AnySyncStateEvent}, serde::Raw, ServerName}; -use crate::{service::*, SERVICE, utils, Error}; +use crate::{Result, services, utils, Error}; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] pub fn update_membership( @@ -24,8 +24,8 @@ impl Service<_> { update_joined_count: bool, ) -> Result<()> { // Keep track what remote users exist by adding them as "deactivated" users - if user_id.server_name() != SERVICE.globals.server_name() { - SERVICE.users.create(user_id, None)?; + if user_id.server_name() != services().globals.server_name() { + services().users.create(user_id, None)?; // TODO: displayname, avatar url } @@ -37,10 +37,6 @@ impl Service<_> { serverroom_id.push(0xff); serverroom_id.extend_from_slice(room_id.as_bytes()); - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - match &membership { MembershipState::Join => { // Check if the user never joined this room @@ -80,24 +76,23 @@ impl Service<_> { // .ok(); // Copy old tags to new room - if let Some(tag_event) = db.account_data.get::( + if let Some(tag_event) = services().account_data.get::( Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag, )? { - SERVICE.account_data + services().account_data .update( Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event, - &db.globals, ) .ok(); }; // Copy direct chat flag - if let Some(mut direct_event) = SERVICE.account_data.get::( + if let Some(mut direct_event) = services().account_data.get::( None, user_id, GlobalAccountDataEventType::Direct.to_string().into(), @@ -112,7 +107,7 @@ impl Service<_> { } if room_ids_updated { - SERVICE.account_data.update( + services().account_data.update( None, user_id, GlobalAccountDataEventType::Direct.to_string().into(), @@ -123,16 +118,11 @@ impl Service<_> { } } - self.userroomid_joined.insert(&userroom_id, &[])?; - self.roomuserid_joined.insert(&roomuser_id, &[])?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; + self.db.mark_as_joined(user_id, room_id)?; } MembershipState::Invite => { // We want to know if the sender is ignored by the receiver - let is_ignored = SERVICE + let is_ignored = services() .account_data .get::( None, // Ignored users are in global account data @@ -153,41 +143,22 @@ impl Service<_> { return Ok(()); } - self.userroomid_invitestate.insert( - &userroom_id, - &serde_json::to_vec(&last_state.unwrap_or_default()) - .expect("state to bytes always works"), - )?; - self.roomuserid_invitecount - .insert(&roomuser_id, &db.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; + self.db.mark_as_invited(user_id, room_id, last_state)?; } MembershipState::Leave | MembershipState::Ban => { - self.userroomid_leftstate.insert( - &userroom_id, - &serde_json::to_vec(&Vec::>::new()).unwrap(), - )?; // TODO - self.roomuserid_leftcount - .insert(&roomuser_id, &db.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; + self.db.mark_as_left(user_id, room_id)?; } _ => {} } if update_joined_count { - self.update_joined_count(room_id, db)?; + self.update_joined_count(room_id)?; } Ok(()) } - #[tracing::instrument(skip(self, room_id, db))] + #[tracing::instrument(skip(self, room_id))] pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { let mut joinedcount = 0_u64; let mut invitedcount = 0_u64; @@ -196,8 +167,8 @@ impl Service<_> { for joined in self.room_members(room_id).filter_map(|r| r.ok()) { joined_servers.insert(joined.server_name().to_owned()); - if joined.server_name() == db.globals.server_name() - && !db.users.is_deactivated(&joined).unwrap_or(true) + if joined.server_name() == services().globals.server_name() + && !services().users.is_deactivated(&joined).unwrap_or(true) { real_users.insert(joined); } @@ -285,7 +256,7 @@ impl Service<_> { .get("sender_localpart") .and_then(|string| string.as_str()) .and_then(|string| { - UserId::parse_with_server_name(string, SERVICE.globals.server_name()).ok() + UserId::parse_with_server_name(string, services().globals.server_name()).ok() }); let in_room = bridge_user_id diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 74a28e7b..17689364 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,4 +1,5 @@ -use crate::service::rooms::CompressedStateEvent; +use super::CompressedStateEvent; +use crate::Result; pub struct StateDiff { parent: Option, @@ -7,6 +8,6 @@ pub struct StateDiff { } pub trait Data { - fn get_statediff(shortstatehash: u64) -> Result; - fn save_statediff(shortstatehash: u64, diff: StateDiff) -> Result<()>; + fn get_statediff(&self, shortstatehash: u64) -> Result; + fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>; } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 3aea4fe6..619e4cf5 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -4,7 +4,7 @@ use std::{mem::size_of, sync::Arc, collections::HashSet}; pub use data::Data; use ruma::{EventId, RoomId}; -use crate::{service::*, utils}; +use crate::{Result, utils, services}; use self::data::StateDiff; @@ -12,7 +12,9 @@ pub struct Service { db: D, } -impl Service<_> { +pub type CompressedStateEvent = [u8; 2 * size_of::()]; + +impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. #[tracing::instrument(skip(self))] pub fn load_shortstatehash_info( @@ -62,12 +64,11 @@ impl Service<_> { &self, shortstatekey: u64, event_id: &EventId, - globals: &super::globals::Globals, ) -> Result { let mut v = shortstatekey.to_be_bytes().to_vec(); v.extend_from_slice( &self - .get_or_create_shorteventid(event_id, globals)? + .get_or_create_shorteventid(event_id)? .to_be_bytes(), ); Ok(v.try_into().expect("we checked the size above")) @@ -210,15 +211,16 @@ impl Service<_> { /// Returns the new shortstatehash pub fn save_state( + &self, room_id: &RoomId, new_state_ids_compressed: HashSet, ) -> Result<(u64, HashSet, // added HashSet)> // removed { - let previous_shortstatehash = self.d.current_shortstatehash(room_id)?; + let previous_shortstatehash = self.db.current_shortstatehash(room_id)?; - let state_hash = self.calculate_hash( + let state_hash = utils::calculate_hash( &new_state_ids_compressed .iter() .map(|bytes| &bytes[..]) @@ -226,7 +228,7 @@ impl Service<_> { ); let (new_shortstatehash, already_existed) = - self.get_or_create_shortstatehash(&state_hash, &db.globals)?; + services().rooms.short.get_or_create_shortstatehash(&state_hash)?; if Some(new_shortstatehash) == previous_shortstatehash { return Ok(()); diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index bf6d8c5e..85bedc69 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use ruma::{signatures::CanonicalJsonObject, EventId, UserId, RoomId}; -use crate::PduEvent; +use crate::{Result, PduEvent}; pub trait Data { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; @@ -48,28 +48,26 @@ pub trait Data { /// Returns an iterator over all events in a room that happened after the event with id `since` /// in chronological order. - #[tracing::instrument(skip(self))] fn pdus_since<'a>( &'a self, user_id: &UserId, room_id: &RoomId, since: u64, - ) -> Result, PduEvent)>> + 'a>; + ) -> Result, PduEvent)>>>>; /// Returns an iterator over all events and their tokens in a room that happened before the /// event with id `until` in reverse-chronological order. - #[tracing::instrument(skip(self))] fn pdus_until<'a>( &'a self, user_id: &UserId, room_id: &RoomId, until: u64, - ) -> Result, PduEvent)>> + 'a>; + ) -> Result, PduEvent)>>>>; fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, from: u64, - ) -> Result, PduEvent)>> + 'a>; + ) -> Result, PduEvent)>>>>; } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 7b60fe5d..09f66ddf 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,23 +1,29 @@ mod data; +use std::borrow::Cow; +use std::sync::Arc; use std::{sync::MutexGuard, iter, collections::HashSet}; use std::fmt::Debug; pub use data::Data; use regex::Regex; +use ruma::events::room::power_levels::RoomPowerLevelsEventContent; +use ruma::push::Ruleset; use ruma::signatures::CanonicalJsonValue; +use ruma::state_res::RoomVersion; use ruma::{EventId, signatures::CanonicalJsonObject, push::{Action, Tweak}, events::{push_rules::PushRulesEvent, GlobalAccountDataEventType, RoomEventType, room::{member::MembershipState, create::RoomCreateEventContent}, StateEventType}, UserId, RoomAliasId, RoomId, uint, state_res, api::client::error::ErrorKind, serde::to_canonical_value, ServerName}; use serde::Deserialize; use serde_json::value::to_raw_value; use tracing::{warn, error}; -use crate::SERVICE; -use crate::{service::{*, pdu::{PduBuilder, EventHash}}, Error, PduEvent, utils}; +use crate::{services, Result, service::pdu::{PduBuilder, EventHash}, Error, PduEvent, utils}; + +use super::state_compressor::CompressedStateEvent; pub struct Service { db: D, } -impl Service<_> { +impl Service { /* /// Checks if a room exists. #[tracing::instrument(skip(self))] @@ -44,7 +50,7 @@ impl Service<_> { #[tracing::instrument(skip(self))] pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - self.db.last_timeline_count(sender_user: &UserId, room_id: &RoomId) + self.db.last_timeline_count(sender_user, room_id) } // TODO Is this the same as the function above? @@ -127,7 +133,7 @@ impl Service<_> { /// Removes a pdu and creates a new one with the same id. #[tracing::instrument(skip(self))] fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { - self.db.pdu_count(pdu_id, pdu: &PduEvent) + self.db.replace_pdu(pdu_id, pdu) } /// Creates a new persisted data unit and adds it to a room. @@ -177,7 +183,7 @@ impl Service<_> { self.replace_pdu_leaves(&pdu.room_id, leaves)?; let mutex_insert = Arc::clone( - db.globals + services().globals .roomid_mutex_insert .write() .unwrap() @@ -186,14 +192,14 @@ impl Service<_> { ); let insert_lock = mutex_insert.lock().unwrap(); - let count1 = db.globals.next_count()?; + let count1 = services().globals.next_count()?; // Mark as read first so the sending client doesn't get a notification even if appending // fails self.edus - .private_read_set(&pdu.room_id, &pdu.sender, count1, &db.globals)?; + .private_read_set(&pdu.room_id, &pdu.sender, count1)?; self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; - let count2 = db.globals.next_count()?; + let count2 = services().globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); @@ -218,7 +224,7 @@ impl Service<_> { drop(insert_lock); // See if the event matches any known pushers - let power_levels: RoomPowerLevelsEventContent = db + let power_levels: RoomPowerLevelsEventContent = services() .rooms .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .map(|ev| { @@ -233,13 +239,13 @@ impl Service<_> { let mut notifies = Vec::new(); let mut highlights = Vec::new(); - for user in self.get_our_real_users(&pdu.room_id, db)?.iter() { + for user in self.get_our_real_users(&pdu.room_id)?.iter() { // Don't notify the user of their own events if user == &pdu.sender { continue; } - let rules_for_user = db + let rules_for_user = services() .account_data .get( None, @@ -252,7 +258,7 @@ impl Service<_> { let mut highlight = false; let mut notify = false; - for action in pusher::get_actions( + for action in services().pusher.get_actions( user, &rules_for_user, &power_levels, @@ -282,8 +288,8 @@ impl Service<_> { highlights.push(userroom_id); } - for senderkey in db.pusher.get_pusher_senderkeys(user) { - db.sending.send_push_pdu(&*pdu_id, senderkey)?; + for senderkey in services().pusher.get_pusher_senderkeys(user) { + services().sending.send_push_pdu(&*pdu_id, senderkey)?; } } @@ -328,7 +334,6 @@ impl Service<_> { content.membership, &pdu.sender, invite_state, - db, true, )?; } @@ -344,34 +349,34 @@ impl Service<_> { .map_err(|_| Error::bad_database("Invalid content in pdu."))?; if let Some(body) = content.body { - DB.rooms.search.index_pdu(room_id, pdu_id, body)?; + services().rooms.search.index_pdu(shortroomid, pdu_id, body)?; - let admin_room = self.id_from_alias( + let admin_room = self.alias.resolve_local_alias( <&RoomAliasId>::try_from( - format!("#admins:{}", db.globals.server_name()).as_str(), + format!("#admins:{}", services().globals.server_name()).as_str(), ) .expect("#admins:server_name is a valid room alias"), )?; - let server_user = format!("@conduit:{}", db.globals.server_name()); + let server_user = format!("@conduit:{}", services().globals.server_name()); let to_conduit = body.starts_with(&format!("{}: ", server_user)); // This will evaluate to false if the emergency password is set up so that // the administrator can execute commands as conduit let from_conduit = - pdu.sender == server_user && db.globals.emergency_password().is_none(); + pdu.sender == server_user && services().globals.emergency_password().is_none(); if to_conduit && !from_conduit && admin_room.as_ref() == Some(&pdu.room_id) { - db.admin.process_message(body.to_string()); + services().admin.process_message(body.to_string()); } } } _ => {} } - for appservice in db.appservice.all()? { - if self.appservice_in_room(room_id, &appservice, db)? { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; + for appservice in services().appservice.all()? { + if self.appservice_in_room(&pdu.room_id, &appservice)? { + services().sending.send_pdu_appservice(&appservice.0, &pdu_id)?; continue; } @@ -388,11 +393,11 @@ impl Service<_> { .get("sender_localpart") .and_then(|string| string.as_str()) .and_then(|string| { - UserId::parse_with_server_name(string, db.globals.server_name()).ok() + UserId::parse_with_server_name(string, services().globals.server_name()).ok() }) { if state_key_uid == &appservice_uid { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; + services().sending.send_pdu_appservice(&appservice.0, &pdu_id)?; continue; } } @@ -431,16 +436,16 @@ impl Service<_> { .map_or(false, |state_key| users.is_match(state_key)) }; let matching_aliases = |aliases: &Regex| { - self.room_aliases(room_id) + self.room_aliases(&pdu.room_id) .filter_map(|r| r.ok()) .any(|room_alias| aliases.is_match(room_alias.as_str())) }; if aliases.iter().any(matching_aliases) - || rooms.map_or(false, |rooms| rooms.contains(&room_id.as_str().into())) + || rooms.map_or(false, |rooms| rooms.contains(&pdu.room_id.as_str().into())) || users.iter().any(matching_users) { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; + services().sending.send_pdu_appservice(&appservice.0, &pdu_id)?; } } } @@ -464,14 +469,14 @@ impl Service<_> { redacts, } = pdu_builder; - let prev_events: Vec<_> = SERVICE + let prev_events: Vec<_> = services() .rooms .get_pdu_leaves(room_id)? .into_iter() .take(20) .collect(); - let create_event = SERVICE + let create_event = services() .rooms .room_state_get(room_id, &StateEventType::RoomCreate, "")?; @@ -488,7 +493,7 @@ impl Service<_> { // If there was no create event yet, assume we are creating a room with the default // version right now let room_version_id = create_event_content - .map_or(SERVICE.globals.default_room_version(), |create_event| { + .map_or(services().globals.default_room_version(), |create_event| { create_event.room_version }); let room_version = @@ -500,7 +505,7 @@ impl Service<_> { // Our depth is the maximum depth of prev_events + 1 let depth = prev_events .iter() - .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) + .filter_map(|event_id| Some(services().rooms.get_pdu(event_id).ok()??.depth)) .max() .unwrap_or_else(|| uint!(0)) + uint!(1); @@ -525,7 +530,7 @@ impl Service<_> { let pdu = PduEvent { event_id: ruma::event_id!("$thiswillbefilledinlater").into(), room_id: room_id.to_owned(), - sender: sender_user.to_owned(), + sender: sender.to_owned(), origin_server_ts: utils::millis_since_unix_epoch() .try_into() .expect("time is valid"), @@ -577,13 +582,13 @@ impl Service<_> { // Add origin because synapse likes that (and it's required in the spec) pdu_json.insert( "origin".to_owned(), - to_canonical_value(db.globals.server_name()) + to_canonical_value(services().globals.server_name()) .expect("server name is a valid CanonicalJsonValue"), ); match ruma::signatures::hash_and_sign_event( - SERVICE.globals.server_name().as_str(), - SERVICE.globals.keypair(), + services().globals.server_name().as_str(), + services().globals.keypair(), &mut pdu_json, &room_version_id, ) { @@ -616,22 +621,20 @@ impl Service<_> { ); // Generate short event id - let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id, &db.globals)?; + let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id)?; } /// Creates a new persisted data unit and adds it to a room. This function takes a /// roomid_mutex_state, meaning that only this function is able to mutate the room state. - #[tracing::instrument(skip(self, _mutex_lock))] + #[tracing::instrument(skip(self, state_lock))] pub fn build_and_append_pdu( &self, pdu_builder: PduBuilder, sender: &UserId, room_id: &RoomId, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result> { - - let (pdu, pdu_json) = self.create_hash_and_sign_event()?; - + let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, &state_lock); // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. @@ -664,9 +667,9 @@ impl Service<_> { } // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above - servers.remove(SERVICE.globals.server_name()); + servers.remove(services().globals.server_name()); - SERVICE.sending.send_pdu(servers.into_iter(), &pdu_id)?; + services().sending.send_pdu(servers.into_iter(), &pdu_id)?; Ok(pdu.event_id) } @@ -684,20 +687,20 @@ impl Service<_> { ) -> Result>> { // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - SERVICE.rooms.set_event_state( + services().rooms.set_event_state( &pdu.event_id, &pdu.room_id, state_ids_compressed, )?; if soft_fail { - SERVICE.rooms + services().rooms .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - SERVICE.rooms.replace_pdu_leaves(&pdu.room_id, new_room_leaves)?; + services().rooms.replace_pdu_leaves(&pdu.room_id, new_room_leaves)?; return Ok(None); } - let pdu_id = SERVICE.rooms.append_pdu(pdu, pdu_json, new_room_leaves)?; + let pdu_id = services().rooms.append_pdu(pdu, pdu_json, new_room_leaves)?; Ok(Some(pdu_id)) } diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 47a44eef..a5657bc1 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,3 +1,6 @@ +use ruma::{UserId, RoomId}; +use crate::Result; + pub trait Data { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; @@ -17,5 +20,5 @@ pub trait Data { fn get_shared_rooms<'a>( &'a self, users: Vec>, - ) -> Result>> + 'a>; + ) -> Result>>>>; } diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 664f8a0a..729887c3 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -2,13 +2,13 @@ mod data; pub use data::Data; use ruma::{RoomId, UserId}; -use crate::service::*; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { self.db.reset_notification_counts(user_id, room_id) } @@ -27,7 +27,7 @@ impl Service<_> { token: u64, shortstatehash: u64, ) -> Result<()> { - self.db.associate_token_shortstatehash(user_id, room_id) + self.db.associate_token_shortstatehash(room_id, token, shortstatehash) } pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { diff --git a/src/service/sending.rs b/src/service/sending/mod.rs similarity index 88% rename from src/service/sending.rs rename to src/service/sending/mod.rs index 4c830d6f..8ab557f6 100644 --- a/src/service/sending.rs +++ b/src/service/sending/mod.rs @@ -6,7 +6,7 @@ use std::{ }; use crate::{ - appservice_server, database::pusher, server_server, utils, Database, Error, PduEvent, Result, + utils, Error, PduEvent, Result, services, api::{server_server, appservice_server}, }; use federation::transactions::send_transaction_message; use futures_util::{stream::FuturesUnordered, StreamExt}; @@ -34,8 +34,6 @@ use tokio::{ }; use tracing::{error, warn}; -use super::abstraction::Tree; - #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum OutgoingKind { Appservice(String), @@ -77,11 +75,8 @@ pub enum SendingEventType { Edu(Vec), } -pub struct Sending { +pub struct Service { /// The state for a given state hash. - pub(super) servername_educount: Arc, // EduCount: Count of last EDU sync - pub(super) servernameevent_data: Arc, // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content - pub(super) servercurrentevent_data: Arc, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content pub(super) maximum_requests: Arc, pub sender: mpsc::UnboundedSender<(Vec, Vec)>, } @@ -92,10 +87,9 @@ enum TransactionStatus { Retrying(u32), // number of times failed } -impl Sending { +impl Service { pub fn start_handler( &self, - db: Arc>, mut receiver: mpsc::UnboundedReceiver<(Vec, Vec)>, ) { tokio::spawn(async move { @@ -106,9 +100,7 @@ impl Sending { // Retry requests we could not finish yet let mut initial_transactions = HashMap::>::new(); - let guard = db.read().await; - - for (key, outgoing_kind, event) in guard + for (key, outgoing_kind, event) in services() .sending .servercurrentevent_data .iter() @@ -127,22 +119,19 @@ impl Sending { "Dropping some current events: {:?} {:?} {:?}", key, outgoing_kind, event ); - guard.sending.servercurrentevent_data.remove(&key).unwrap(); + services().sending.servercurrentevent_data.remove(&key).unwrap(); continue; } entry.push(event); } - drop(guard); - for (outgoing_kind, events) in initial_transactions { current_transaction_status .insert(outgoing_kind.get_prefix(), TransactionStatus::Running); futures.push(Self::handle_events( outgoing_kind.clone(), events, - Arc::clone(&db), )); } @@ -151,17 +140,15 @@ impl Sending { Some(response) = futures.next() => { match response { Ok(outgoing_kind) => { - let guard = db.read().await; - let prefix = outgoing_kind.get_prefix(); - for (key, _) in guard.sending.servercurrentevent_data + for (key, _) in services().sending.servercurrentevent_data .scan_prefix(prefix.clone()) { - guard.sending.servercurrentevent_data.remove(&key).unwrap(); + services().sending.servercurrentevent_data.remove(&key).unwrap(); } // Find events that have been added since starting the last request - let new_events: Vec<_> = guard.sending.servernameevent_data + let new_events: Vec<_> = services().sending.servernameevent_data .scan_prefix(prefix.clone()) .filter_map(|(k, v)| { Self::parse_servercurrentevent(&k, v).ok().map(|ev| (ev, k)) @@ -175,17 +162,14 @@ impl Sending { // Insert pdus we found for (e, key) in &new_events { let value = if let SendingEventType::Edu(value) = &e.1 { &**value } else { &[] }; - guard.sending.servercurrentevent_data.insert(key, value).unwrap(); - guard.sending.servernameevent_data.remove(key).unwrap(); + services().sending.servercurrentevent_data.insert(key, value).unwrap(); + services().sending.servernameevent_data.remove(key).unwrap(); } - drop(guard); - futures.push( Self::handle_events( outgoing_kind.clone(), new_events.into_iter().map(|(event, _)| event.1).collect(), - Arc::clone(&db), ) ); } else { @@ -206,15 +190,12 @@ impl Sending { }, Some((key, value)) = receiver.recv() => { if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key, value) { - let guard = db.read().await; - if let Ok(Some(events)) = Self::select_events( &outgoing_kind, vec![(event, key)], &mut current_transaction_status, - &guard ) { - futures.push(Self::handle_events(outgoing_kind, events, Arc::clone(&db))); + futures.push(Self::handle_events(outgoing_kind, events)); } } } @@ -223,12 +204,11 @@ impl Sending { }); } - #[tracing::instrument(skip(outgoing_kind, new_events, current_transaction_status, db))] + #[tracing::instrument(skip(outgoing_kind, new_events, current_transaction_status))] fn select_events( outgoing_kind: &OutgoingKind, new_events: Vec<(SendingEventType, Vec)>, // Events we want to send: event and full key current_transaction_status: &mut HashMap, TransactionStatus>, - db: &Database, ) -> Result>> { let mut retry = false; let mut allow = true; @@ -266,7 +246,7 @@ impl Sending { if retry { // We retry the previous transaction - for (key, value) in db.sending.servercurrentevent_data.scan_prefix(prefix) { + for (key, value) in services().sending.servercurrentevent_data.scan_prefix(prefix) { if let Ok((_, e)) = Self::parse_servercurrentevent(&key, value) { events.push(e); } @@ -278,22 +258,22 @@ impl Sending { } else { &[][..] }; - db.sending + services().sending .servercurrentevent_data .insert(&full_key, value)?; // If it was a PDU we have to unqueue it // TODO: don't try to unqueue EDUs - db.sending.servernameevent_data.remove(&full_key)?; + services().sending.servernameevent_data.remove(&full_key)?; events.push(e); } if let OutgoingKind::Normal(server_name) = outgoing_kind { - if let Ok((select_edus, last_count)) = Self::select_edus(db, server_name) { + if let Ok((select_edus, last_count)) = Self::select_edus(server_name) { events.extend(select_edus.into_iter().map(SendingEventType::Edu)); - db.sending + services().sending .servername_educount .insert(server_name.as_bytes(), &last_count.to_be_bytes())?; } @@ -303,10 +283,10 @@ impl Sending { Ok(Some(events)) } - #[tracing::instrument(skip(db, server))] - pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec>, u64)> { + #[tracing::instrument(skip(server))] + pub fn select_edus(server: &ServerName) -> Result<(Vec>, u64)> { // u64: count of last edu - let since = db + let since = services() .sending .servername_educount .get(server.as_bytes())? @@ -318,25 +298,25 @@ impl Sending { let mut max_edu_count = since; let mut device_list_changes = HashSet::new(); - 'outer: for room_id in db.rooms.server_rooms(server) { + 'outer: for room_id in services().rooms.server_rooms(server) { let room_id = room_id?; // Look for device list updates in this room device_list_changes.extend( - db.users + services().users .keys_changed(&room_id.to_string(), since, None) .filter_map(|r| r.ok()) - .filter(|user_id| user_id.server_name() == db.globals.server_name()), + .filter(|user_id| user_id.server_name() == services().globals.server_name()), ); // Look for read receipts in this room - for r in db.rooms.edus.readreceipts_since(&room_id, since) { + for r in services().rooms.edus.readreceipts_since(&room_id, since) { let (user_id, count, read_receipt) = r?; if count > max_edu_count { max_edu_count = count; } - if user_id.server_name() != db.globals.server_name() { + if user_id.server_name() != services().globals.server_name() { continue; } @@ -496,14 +476,11 @@ impl Sending { Ok(()) } - #[tracing::instrument(skip(db, events, kind))] + #[tracing::instrument(skip(events, kind))] async fn handle_events( kind: OutgoingKind, events: Vec, - db: Arc>, ) -> Result { - let db = db.read().await; - match &kind { OutgoingKind::Appservice(id) => { let mut pdu_jsons = Vec::new(); @@ -511,7 +488,7 @@ impl Sending { for event in &events { match event { SendingEventType::Pdu(pdu_id) => { - pdu_jsons.push(db.rooms + pdu_jsons.push(services().rooms .get_pdu_from_id(pdu_id) .map_err(|e| (kind.clone(), e))? .ok_or_else(|| { @@ -530,11 +507,10 @@ impl Sending { } } - let permit = db.sending.maximum_requests.acquire().await; + let permit = services().sending.maximum_requests.acquire().await; let response = appservice_server::send_request( - &db.globals, - db.appservice + services().appservice .get_registration(&id) .map_err(|e| (kind.clone(), e))? .ok_or_else(|| { @@ -576,7 +552,7 @@ impl Sending { match event { SendingEventType::Pdu(pdu_id) => { pdus.push( - db.rooms + services().rooms .get_pdu_from_id(pdu_id) .map_err(|e| (kind.clone(), e))? .ok_or_else(|| { @@ -624,7 +600,7 @@ impl Sending { senderkey.push(0xff); senderkey.extend_from_slice(pushkey); - let pusher = match db + let pusher = match services() .pusher .get_pusher(&senderkey) .map_err(|e| (OutgoingKind::Push(user.clone(), pushkey.clone()), e))? @@ -633,7 +609,7 @@ impl Sending { None => continue, }; - let rules_for_user = db + let rules_for_user = services() .account_data .get( None, @@ -644,22 +620,21 @@ impl Sending { .map(|ev: PushRulesEvent| ev.content.global) .unwrap_or_else(|| push::Ruleset::server_default(&userid)); - let unread: UInt = db + let unread: UInt = services() .rooms .notification_count(&userid, &pdu.room_id) .map_err(|e| (kind.clone(), e))? .try_into() .expect("notifiation count can't go that high"); - let permit = db.sending.maximum_requests.acquire().await; + let permit = services().sending.maximum_requests.acquire().await; - let _response = pusher::send_push_notice( + let _response = services().pusher.send_push_notice( &userid, unread, &pusher, rules_for_user, &pdu, - &db, ) .await .map(|_response| kind.clone()) @@ -678,7 +653,7 @@ impl Sending { SendingEventType::Pdu(pdu_id) => { // TODO: check room version and remove event_id if needed let raw = PduEvent::convert_to_outgoing_federation_event( - db.rooms + services().rooms .get_pdu_json_from_id(pdu_id) .map_err(|e| (OutgoingKind::Normal(server.clone()), e))? .ok_or_else(|| { @@ -700,13 +675,12 @@ impl Sending { } } - let permit = db.sending.maximum_requests.acquire().await; + let permit = services().sending.maximum_requests.acquire().await; let response = server_server::send_request( - &db.globals, &*server, send_transaction_message::v1::Request { - origin: db.globals.server_name(), + origin: services().globals.server_name(), pdus: &pdu_jsons, edus: &edu_jsons, origin_server_ts: MilliSecondsSinceUnixEpoch::now(), @@ -809,10 +783,9 @@ impl Sending { }) } - #[tracing::instrument(skip(self, globals, destination, request))] + #[tracing::instrument(skip(self, destination, request))] pub async fn send_federation_request( &self, - globals: &crate::database::globals::Globals, destination: &ServerName, request: T, ) -> Result @@ -820,16 +793,15 @@ impl Sending { T: Debug, { let permit = self.maximum_requests.acquire().await; - let response = server_server::send_request(globals, destination, request).await; + let response = server_server::send_request(destination, request).await; drop(permit); response } - #[tracing::instrument(skip(self, globals, registration, request))] + #[tracing::instrument(skip(self, registration, request))] pub async fn send_appservice_request( &self, - globals: &crate::database::globals::Globals, registration: serde_yaml::Value, request: T, ) -> Result @@ -837,7 +809,7 @@ impl Sending { T: Debug, { let permit = self.maximum_requests.acquire().await; - let response = appservice_server::send_request(globals, registration, request).await; + let response = appservice_server::send_request(registration, request).await; drop(permit); response diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs index c1b47154..6e71dd46 100644 --- a/src/service/transaction_ids/data.rs +++ b/src/service/transaction_ids/data.rs @@ -1,3 +1,6 @@ +use ruma::{DeviceId, UserId, TransactionId}; +use crate::Result; + pub trait Data { fn add_txnid( &self, diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index 9b76e13b..ea923722 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,14 +1,14 @@ mod data; pub use data::Data; -use ruma::{UserId, DeviceId, TransactionId}; -use crate::service::*; +use ruma::{UserId, DeviceId, TransactionId}; +use crate::Result; pub struct Service { db: D, } -impl Service<_> { +impl Service { pub fn add_txnid( &self, user_id: &UserId, diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs index cc943bff..d7fa79d2 100644 --- a/src/service/uiaa/data.rs +++ b/src/service/uiaa/data.rs @@ -1,4 +1,5 @@ use ruma::{api::client::uiaa::UiaaInfo, DeviceId, UserId, signatures::CanonicalJsonValue}; +use crate::Result; pub trait Data { fn set_uiaa_request( diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 5e1df8f3..ffdbf356 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,15 +1,16 @@ mod data; pub use data::Data; -use ruma::{api::client::{uiaa::{UiaaInfo, IncomingAuthData, IncomingPassword, AuthType}, error::ErrorKind}, DeviceId, UserId, signatures::CanonicalJsonValue}; + +use ruma::{api::client::{uiaa::{UiaaInfo, IncomingAuthData, IncomingPassword, AuthType, IncomingUserIdentifier}, error::ErrorKind}, DeviceId, UserId, signatures::CanonicalJsonValue}; use tracing::error; -use crate::{service::*, utils, Error, SERVICE}; +use crate::{Result, utils, Error, services, api::client_server::SESSION_ID_LENGTH}; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Creates a new Uiaa session. Make sure the session token is unique. pub fn create( &self, @@ -56,7 +57,7 @@ impl Service<_> { .. }) => { let username = match identifier { - UserIdOrLocalpart(username) => username, + IncomingUserIdentifier::UserIdOrLocalpart(username) => username, _ => { return Err(Error::BadRequest( ErrorKind::Unrecognized, @@ -66,13 +67,13 @@ impl Service<_> { }; let user_id = - UserId::parse_with_server_name(username.clone(), SERVICE.globals.server_name()) + UserId::parse_with_server_name(username.clone(), services().globals.server_name()) .map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") })?; // Check if password is correct - if let Some(hash) = SERVICE.users.password_hash(&user_id)? { + if let Some(hash) = services().users.password_hash(&user_id)? { let hash_matches = argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 327e0c69..3f87589c 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,8 +1,8 @@ use std::collections::BTreeMap; - +use crate::Result; use ruma::{UserId, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, serde::Raw, encryption::{OneTimeKey, DeviceKeys, CrossSigningKey}, UInt, events::AnyToDeviceEvent, api::client::{device::Device, filter::IncomingFilterDefinition}, MxcUri}; -trait Data { +pub trait Data { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result; @@ -16,7 +16,7 @@ trait Data { fn find_from_token(&self, token: &str) -> Result, String)>>; /// Returns an iterator over all users on this homeserver. - fn iter(&self) -> impl Iterator>> + '_; + fn iter(&self) -> Box>>>; /// Returns a list of local users as list of usernames. /// @@ -69,7 +69,7 @@ trait Data { fn all_device_ids<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator>> + 'a; + ) -> Box>>>; /// Replaces the access token of one device. fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; @@ -125,7 +125,7 @@ trait Data { user_or_room_id: &str, from: u64, to: Option, - ) -> impl Iterator>> + 'a; + ) -> Box>>>; fn mark_device_key_update( &self, @@ -193,7 +193,7 @@ trait Data { fn all_devices_metadata<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator> + 'a; + ) -> Box>>; /// Creates a new sync filter. Returns the filter id. fn create_filter( diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index bfa4b8e5..dfe6c7fb 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -2,15 +2,15 @@ mod data; use std::{collections::BTreeMap, mem}; pub use data::Data; -use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, DeviceKeyAlgorithm, UInt, events::AnyToDeviceEvent, api::client::{device::Device, filter::IncomingFilterDefinition}}; +use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, DeviceKeyAlgorithm, UInt, events::AnyToDeviceEvent, api::client::{device::Device, filter::IncomingFilterDefinition, error::ErrorKind}, RoomAliasId}; -use crate::{service::*, Error}; +use crate::{Result, Error, services}; pub struct Service { db: D, } -impl Service<_> { +impl Service { /// Check if a user has an account on this homeserver. pub fn exists(&self, user_id: &UserId) -> Result { self.db.exists(user_id) @@ -22,19 +22,19 @@ impl Service<_> { } /// Check if a user is an admin - fn is_admin( + pub fn is_admin( &self, user_id: &UserId, ) -> Result { - let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", globals.server_name())) + let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", services().globals.server_name())) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; - let admin_room_id = rooms.id_from_alias(&admin_room_alias_id)?.unwrap(); + let admin_room_id = services().rooms.alias.resolve_local_alias(&admin_room_alias_id)?.unwrap(); - rooms.is_joined(user_id, &admin_room_id) + services().rooms.state_cache.is_joined(user_id, &admin_room_id) } /// Create a new user account on this homeserver. - fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { self.db.set_password(user_id, password)?; Ok(()) }