Work on rooms/state and database

This commit is contained in:
Timo Kösters 2022-06-25 16:12:23 +02:00
parent 03b2867a84
commit 7c166aa468
No known key found for this signature in database
GPG key ID: 356E705610F626D5
6 changed files with 144 additions and 136 deletions

View file

@ -26,7 +26,7 @@ pub mod persy;
))]
pub mod watchers;
pub trait DatabaseEngine: Send + Sync {
pub trait KeyValueDatabaseEngine: Send + Sync {
fn open(config: &Config) -> Result<Self>
where
Self: Sized;
@ -40,7 +40,7 @@ pub trait DatabaseEngine: Send + Sync {
}
}
pub trait Tree: Send + Sync {
pub trait KeyValueTree: Send + Sync {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;

65
src/database/key_value.rs Normal file
View file

@ -0,0 +1,65 @@
use crate::service;
impl service::room::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?))
})
}
fn set_room_state(room_id: &RoomId, new_shortstatehash: u64
_mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
fn set_event_state() -> Result<()> {
db.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(())
}
fn get_pdu_leaves(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
})
.collect()
}
fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: impl IntoIterator<Item = &'a EventId> + Debug,
_mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;
}
for event_id in event_ids {
let mut key = prefix.to_owned();
key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
}
Ok(())
}
}

View file

@ -15,7 +15,7 @@ pub mod users;
use self::admin::create_admin_room;
use crate::{utils, Config, Error, Result};
use abstraction::DatabaseEngine;
use abstraction::KeyValueDatabaseEngine;
use directories::ProjectDirs;
use futures_util::{stream::FuturesUnordered, StreamExt};
use lru_cache::LruCache;
@ -39,8 +39,8 @@ use std::{
use tokio::sync::{mpsc, OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore};
use tracing::{debug, error, info, warn};
pub struct Database {
_db: Arc<dyn DatabaseEngine>,
pub struct KeyValueDatabase {
_db: Arc<dyn KeyValueDatabaseEngine>,
pub globals: globals::Globals,
pub users: users::Users,
pub uiaa: uiaa::Uiaa,
@ -55,7 +55,7 @@ pub struct Database {
pub pusher: pusher::PushData,
}
impl Database {
impl KeyValueDatabase {
/// Tries to remove the old database but ignores all errors.
pub fn try_remove(server_name: &str) -> Result<()> {
let mut path = ProjectDirs::from("xyz", "koesters", "conduit")
@ -124,7 +124,7 @@ impl Database {
.map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?;
}
let builder: Arc<dyn DatabaseEngine> = match &*config.database_backend {
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend {
"sqlite" => {
#[cfg(not(feature = "sqlite"))]
return Err(Error::BadConfig("Database backend not found."));
@ -955,7 +955,7 @@ impl Database {
}
/// Sets the emergency password and push rules for the @conduit account in case emergency password is set
fn set_emergency_access(db: &Database) -> Result<bool> {
fn set_emergency_access(db: &KeyValueDatabase) -> Result<bool> {
let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name())
.expect("@conduit:server_name is a valid UserId");
@ -979,39 +979,3 @@ fn set_emergency_access(db: &Database) -> Result<bool> {
res
}
pub struct DatabaseGuard(OwnedRwLockReadGuard<Database>);
impl Deref for DatabaseGuard {
type Target = OwnedRwLockReadGuard<Database>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[cfg(feature = "conduit_bin")]
#[axum::async_trait]
impl<B> axum::extract::FromRequest<B> for DatabaseGuard
where
B: Send,
{
type Rejection = axum::extract::rejection::ExtensionRejection;
async fn from_request(
req: &mut axum::extract::RequestParts<B>,
) -> Result<Self, Self::Rejection> {
use axum::extract::Extension;
let Extension(db): Extension<Arc<TokioRwLock<Database>>> =
Extension::from_request(req).await?;
Ok(DatabaseGuard(db.read_owned().await))
}
}
impl From<OwnedRwLockReadGuard<Database>> for DatabaseGuard {
fn from(val: OwnedRwLockReadGuard<Database>) -> Self {
Self(val)
}
}

View file

@ -46,27 +46,26 @@ use tikv_jemallocator::Jemalloc;
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;
#[tokio::main]
async fn main() {
let raw_config =
Figment::new()
.merge(
Toml::file(Env::var("CONDUIT_CONFIG").expect(
"The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml",
))
.nested(),
)
.merge(Env::prefixed("CONDUIT_").global());
lazy_static! {
static ref DB: Database = {
let raw_config =
Figment::new()
.merge(
Toml::file(Env::var("CONDUIT_CONFIG").expect(
"The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml",
))
.nested(),
)
.merge(Env::prefixed("CONDUIT_").global());
let config = match raw_config.extract::<Config>() {
Ok(s) => s,
Err(e) => {
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e);
std::process::exit(1);
}
};
let config = match raw_config.extract::<Config>() {
Ok(s) => s,
Err(e) => {
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e);
std::process::exit(1);
}
};
let start = async {
config.warn_deprecated();
let db = match Database::load_or_create(&config).await {
@ -79,8 +78,15 @@ async fn main() {
std::process::exit(1);
}
};
};
}
run_server(&config, db).await.unwrap();
#[tokio::main]
async fn main() {
lazy_static::initialize(&DB);
let start = async {
run_server(&config).await.unwrap();
};
if config.allow_jaeger {
@ -120,7 +126,8 @@ async fn main() {
}
}
async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<()> {
async fn run_server() -> io::Result<()> {
let config = DB.globals.config;
let addr = SocketAddr::from((config.address, config.port));
let x_requested_with = HeaderName::from_static("x-requested-with");

View file

@ -1,16 +1,24 @@
pub trait Data {
/// Returns the last state hash key added to the db for the given room.
fn get_room_shortstatehash(room_id: &RoomId);
/// Update the current state of the room.
fn set_room_state(room_id: &RoomId, new_shortstatehash: u64
_mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex
);
/// Associates a state with an event.
fn set_event_state(shorteventid: u64, shortstatehash: u64) -> Result<()> {
/// Returns all events we would send as the prev_events of the next event.
fn get_forward_extremities(room_id: &RoomId) -> Result<HashSet<Arc<EventId>>>;
/// Replace the forward extremities of the room.
fn set_forward_extremities(
room_id: &RoomId,
event_ids: impl IntoIterator<Item = &'_ EventId> + Debug,
_mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
}
/// Returns the last state hash key added to the db for the given room.
#[tracing::instrument(skip(self))]
pub fn current_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?))
})
}
pub struct StateLock;

View file

@ -1,8 +1,13 @@
mod data;
pub use data::Data;
use crate::service::*;
pub struct Service<D: Data> {
db: D,
}
impl Service {
impl Service<_> {
/// Set the room to the given statehash and update caches.
#[tracing::instrument(skip(self, new_state_ids_compressed, db))]
pub fn force_state(
@ -15,11 +20,11 @@ impl Service {
) -> Result<()> {
for event_id in statediffnew.into_iter().filter_map(|new| {
self.parse_compressed_state_event(new)
state_compressor::parse_compressed_state_event(new)
.ok()
.map(|(_, id)| id)
}) {
let pdu = match self.get_pdu_json(&event_id)? {
let pdu = match timeline::get_pdu_json(&event_id)? {
Some(pdu) => pdu,
None => continue,
};
@ -55,56 +60,12 @@ impl Service {
Err(_) => continue,
};
self.update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?;
room::state_cache::update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?;
}
self.update_joined_count(room_id, db)?;
room::state_cache::update_joined_count(room_id, db)?;
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
/// Returns the leaf pdus of a room.
#[tracing::instrument(skip(self))]
pub fn get_pdu_leaves(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
})
.collect()
}
/// Replace the leaves of a room.
///
/// The provided `event_ids` become the new leaves, this allows a room to have multiple
/// `prev_events`.
#[tracing::instrument(skip(self))]
pub fn replace_pdu_leaves<'a>(
&self,
room_id: &RoomId,
event_ids: impl IntoIterator<Item = &'a EventId> + Debug,
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;
}
for event_id in event_ids {
let mut key = prefix.to_owned();
key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
}
db.set_room_state(room_id, new_shortstatehash);
Ok(())
}
@ -121,11 +82,11 @@ impl Service {
state_ids_compressed: HashSet<CompressedStateEvent>,
globals: &super::globals::Globals,
) -> Result<()> {
let shorteventid = self.get_or_create_shorteventid(event_id, globals)?;
let shorteventid = short::get_or_create_shorteventid(event_id, globals)?;
let previous_shortstatehash = self.current_shortstatehash(room_id)?;
let previous_shortstatehash = db.get_room_shortstatehash(room_id)?;
let state_hash = self.calculate_hash(
let state_hash = super::calculate_hash(
&state_ids_compressed
.iter()
.map(|s| &s[..])
@ -133,11 +94,11 @@ impl Service {
);
let (shortstatehash, already_existed) =
self.get_or_create_shortstatehash(&state_hash, globals)?;
short::get_or_create_shortstatehash(&state_hash, globals)?;
if !already_existed {
let states_parents = previous_shortstatehash
.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
.map_or_else(|| Ok(Vec::new()), |p| room::state_compressor.load_shortstatehash_info(p))?;
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
@ -156,7 +117,7 @@ impl Service {
} else {
(state_ids_compressed, HashSet::new())
};
self.save_state_from_diff(
state_compressor::save_state_from_diff(
shortstatehash,
statediffnew,
statediffremoved,
@ -165,8 +126,7 @@ impl Service {
)?;
}
self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
db.set_event_state(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(())
}
@ -183,7 +143,7 @@ impl Service {
) -> Result<u64> {
let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?;
let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?;
let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?;
if let Some(p) = previous_shortstatehash {
self.shorteventid_shortstatehash
@ -293,4 +253,8 @@ impl Service {
Ok(())
}
pub fn db(&self) -> D {
&self.db
}
}