Fetch server keys concurrently #985

Open
_ZN3val wants to merge 2 commits from concurrent-fetch-keys into next
6 changed files with 165 additions and 53 deletions

View file

@ -1033,6 +1033,11 @@ async fn join_room_by_id_helper(
drop(state_lock);
let pub_key_map = RwLock::new(BTreeMap::new());
services()
.rooms
.event_handler
.fetch_required_signing_keys([&signed_value], &pub_key_map)
.await?;
services()
.rooms
.event_handler
@ -1259,6 +1264,12 @@ pub(crate) async fn invite_helper<'a>(
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?;
services()
.rooms
.event_handler
.fetch_required_signing_keys([&value], &pub_key_map)
.await?;
let pdu_id: Vec<u8> = services()
.rooms
.event_handler

View file

@ -210,7 +210,10 @@ where
let keys_result = services()
.rooms
.event_handler
.fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.to_owned()])
.fetch_signing_keys_for_server(
&x_matrix.origin,
vec![x_matrix.key.to_owned()],
)
.await;
let keys = match keys_result {

View file

@ -706,6 +706,7 @@ pub async fn send_transaction_message_route(
// events that it references.
// let mut auth_cache = EventMap::new();
let mut parsed_pdus = vec![];
for pdu in &body.pdus {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
@ -733,8 +734,28 @@ pub async fn send_transaction_message_route(
continue;
}
};
parsed_pdus.push((event_id, value, room_id));
// We do not add the event_id field to the pdu here because of signature and hashes checks
}
// We go through all the signatures we see on the PDUs and fetch the corresponding
// signing keys
services()
.rooms
.event_handler
.fetch_required_signing_keys(
parsed_pdus.iter().map(|(_event_id, event, _room_id)| event),
&pub_key_map,
)
.await
.unwrap_or_else(|e| {
warn!(
"Could not fetch all signatures for PDUs from {}: {:?}",
sender_servername, e
)
});
for (event_id, value, room_id) in parsed_pdus {
let mutex = Arc::clone(
services()
.globals
@ -1574,6 +1595,12 @@ async fn create_join_event(
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?;
services()
.rooms
.event_handler
.fetch_required_signing_keys([&value], &pub_key_map)
.await?;
let mutex = Arc::clone(
services()
.globals

View file

@ -803,7 +803,7 @@ impl Service {
services()
.rooms
.event_handler
.fetch_required_signing_keys(&value, &pub_key_map)
.fetch_required_signing_keys([&value], &pub_key_map)
.await?;
let pub_key_map = pub_key_map.read().unwrap();

View file

@ -288,16 +288,11 @@ impl Service {
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>> {
Box::pin(async move {
// 1.1. Remove unsigned field
// 1. Remove unsigned field
value.remove("unsigned");
// TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json
// We go through all the signatures we see on the value and fetch the corresponding signing
// keys
self.fetch_required_signing_keys(&value, pub_key_map)
.await?;
// 2. Check signatures, otherwise drop
// 3. check content hash, redact if doesn't match
let create_event_content: RoomCreateEventContent =
@ -1034,14 +1029,14 @@ impl Service {
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
};
let mut pdus = vec![];
let mut events_with_auth_events = vec![];
for id in events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
events_with_auth_events.push((id, Some(local_pdu), vec![]));
continue;
}
@ -1140,7 +1135,36 @@ impl Service {
}
}
}
events_with_auth_events.push((id, None, events_in_reverse_order))
}
// We go through all the signatures we see on the PDUs and their unresolved
// dependencies and fetch the corresponding signing keys
info!("fetch_required_signing_keys for {}", origin);
self.fetch_required_signing_keys(
events_with_auth_events
.iter()
.flat_map(|(_id, _local_pdu, events)| events)
.map(|(_event_id, event)| event),
pub_key_map,
)
.await
.unwrap_or_else(|e| {
warn!(
"Could not fetch all signatures for PDUs from {}: {:?}",
origin, e
)
});
let mut pdus = vec![];
for (id, local_pdu, events_in_reverse_order) in events_with_auth_events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Some(local_pdu) = local_pdu {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
}
for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) = services()
.globals
@ -1291,53 +1315,94 @@ impl Service {
}
#[tracing::instrument(skip_all)]
pub(crate) async fn fetch_required_signing_keys(
&self,
event: &BTreeMap<String, CanonicalJsonValue>,
pub(crate) async fn fetch_required_signing_keys<'a, E>(
&'a self,
events: E,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
let signatures = event
.get("signatures")
.ok_or(Error::BadServerResponse(
"No signatures in server response pdu.",
))?
.as_object()
.ok_or(Error::BadServerResponse(
"Invalid signatures object in server response pdu.",
))?;
) -> Result<()>
where
E: IntoIterator<Item = &'a BTreeMap<String, CanonicalJsonValue>>,
{
let mut server_key_ids = HashMap::new();
// We go through all the signatures we see on the value and fetch the corresponding signing
// keys
for (signature_server, signature) in signatures {
let signature_object = signature.as_object().ok_or(Error::BadServerResponse(
"Invalid signatures content object in server response pdu.",
))?;
for event in events.into_iter() {
for (signature_server, signature) in event
.get("signatures")
.ok_or(Error::BadServerResponse(
"No signatures in server response pdu.",
))?
.as_object()
.ok_or(Error::BadServerResponse(
"Invalid signatures object in server response pdu.",
))?
{
let signature_object = signature.as_object().ok_or(Error::BadServerResponse(
"Invalid signatures content object in server response pdu.",
))?;
let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>();
let fetch_res = self
.fetch_signing_keys(
signature_server.as_str().try_into().map_err(|_| {
Error::BadServerResponse(
"Invalid servername in signatures of server response pdu.",
)
})?,
signature_ids,
)
.await;
let keys = match fetch_res {
Ok(keys) => keys,
Err(_) => {
warn!("Signature verification failed: Could not fetch signing key.",);
continue;
for signature_id in signature_object.keys() {
server_key_ids
.entry(signature_server.clone())
.or_insert_with(HashSet::new)
.insert(signature_id.clone());
}
};
}
}
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(signature_server.clone(), keys);
if server_key_ids.is_empty() {
// Nothing to do, can exit early
return Ok(());
}
info!(
"Fetch keys for {}",
server_key_ids
.keys()
.cloned()
.collect::<Vec<_>>()
.join(", ")
);
let mut server_keys: FuturesUnordered<_> = server_key_ids
.into_iter()
.map(|(signature_server, signature_ids)| async {
let signature_server2 = signature_server.clone();
let fetch_res = self
.fetch_signing_keys_for_server(
signature_server2.as_str().try_into().map_err(|_| {
(
signature_server.clone(),
Error::BadServerResponse(
"Invalid servername in signatures of server response pdu.",
),
)
})?,
signature_ids.into_iter().collect(), // HashSet to Vec
)
.await;
match fetch_res {
Ok(keys) => Ok((signature_server, keys)),
Err(e) => {
warn!("Signature verification failed: Could not fetch signing key.",);
Err((signature_server, e))
}
}
})
.collect();
while let Some(fetch_res) = server_keys.next().await {
match fetch_res {
Ok((signature_server, keys)) => {
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(signature_server.clone(), keys);
}
Err((signature_server, e)) => {
warn!("Failed to fetch keys for {}: {:?}", signature_server, e);
}
}
}
Ok(())
@ -1596,7 +1661,7 @@ impl Service {
/// Search the DB for the signing keys of the given server, if we don't have them
/// fetch them from the server and save to our DB.
#[tracing::instrument(skip_all)]
pub async fn fetch_signing_keys(
pub async fn fetch_signing_keys_for_server(
&self,
origin: &ServerName,
signature_ids: Vec<String>,

View file

@ -1139,6 +1139,12 @@ impl Service {
return Ok(());
}
services()
.rooms
.event_handler
.fetch_required_signing_keys([&value], &pub_key_map)
.await?;
services()
.rooms
.event_handler