diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 8294b85b..2da88e6c 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1974,7 +1974,13 @@ pub async fn claim_keys_route( #[cfg(test)] mod tests { + use super::linearize_previous_events; use super::{add_port_to_hostname, get_ip_with_port, FedDest}; + use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId}; + use serde::{Deserialize, Serialize}; + use serde_json::value::RawValue; + use serde_json::Value; + use std::collections::HashMap; #[test] fn ips_get_default_ports() { @@ -2015,4 +2021,227 @@ mod tests { FedDest::Named(String::from("example.com"), String::from(":1337")) ) } + + type PduStorage = HashMap; + + #[derive(Debug, Serialize, Deserialize)] + struct MockPDU { + content: i32, + prev_events: Vec, + } + + fn mock_event_id(id: &i32) -> OwnedEventId { + const DOMAIN: &str = "canterlot.eq"; + ::try_from(format!("${id}:{DOMAIN}")).unwrap() + } + + fn create_graph(data: Vec<(i32, Vec)>) -> PduStorage { + data.iter() + .map(|(head, tail)| { + let key = mock_event_id(head); + let pdu = MockPDU { + content: *head, + prev_events: tail.iter().map(mock_event_id).collect(), + }; + let value = serde_json::to_value(pdu).unwrap(); + let value: CanonicalJsonValue = value.try_into().unwrap(); + (key, value.as_object().unwrap().to_owned()) + }) + .collect() + } + + fn mock_full_graph() -> PduStorage { + /* + (1) + __________|___________ + / / \ \ + (2) (3) (10) (11) + / \ / \ | | + (4) (5) (6) (7) (12) (13) + | | | + (8) (9) (14) + \ / + (15) + | + (16) + */ + create_graph(vec![ + (1, vec![2, 3, 10, 11]), + (2, vec![4, 5]), + (3, vec![6, 7]), + (4, vec![]), + (5, vec![8]), + (6, vec![9]), + (7, vec![]), + (8, vec![15]), + (9, vec![15]), + (10, vec![12]), + (11, vec![13]), + (12, vec![]), + (13, vec![14]), + (14, vec![]), + (15, vec![16]), + (16, vec![16]), + ]) + } + + fn extract_events_payload(events: Vec>) -> Vec { + events + .iter() + .map(|e| serde_json::from_str(e.get()).unwrap()) + .map(|p: MockPDU| p.content) + .collect() + } + + #[test] + fn backfill_empty() { + let events = linearize_previous_events( + vec![], + vec![], + 16u64, + |_| unreachable!(), + |_| true, + |_| true, + ); + assert!(events.is_empty()); + } + #[test] + fn backfill_limit() { + /* + (5) → (4) → (3) → (2) → (1) → × + */ + let events = create_graph(vec![ + (1, vec![]), + (2, vec![1]), + (3, vec![2]), + (4, vec![3]), + (5, vec![4]), + ]); + let roots = vec![mock_event_id(&5)]; + let result = linearize_previous_events( + roots, + vec![], + 3u64, + |e| events.get(e).cloned(), + |_| true, + |_| true, + ); + + assert_eq!(extract_events_payload(result), vec![5, 4, 3]) + } + + #[test] + fn backfill_bfs() { + let events = mock_full_graph(); + let roots = vec![mock_event_id(&1)]; + let result = linearize_previous_events( + roots, + vec![], + 100u64, + |e| events.get(e).cloned(), + |_| true, + |_| true, + ); + assert_eq!( + extract_events_payload(result), + vec![1, 2, 3, 10, 11, 4, 5, 6, 7, 12, 13, 8, 9, 14, 15, 16] + ) + } + + #[test] + fn backfill_subgraph() { + let events = mock_full_graph(); + let roots = vec![mock_event_id(&3)]; + let result = linearize_previous_events( + roots, + vec![], + 100u64, + |e| events.get(e).cloned(), + |_| true, + |_| true, + ); + assert_eq!(extract_events_payload(result), vec![3, 6, 7, 9, 15, 16]) + } + + #[test] + fn backfill_two_roots() { + let events = mock_full_graph(); + let roots = vec![mock_event_id(&3), mock_event_id(&11)]; + let result = linearize_previous_events( + roots, + vec![], + 100u64, + |e| events.get(e).cloned(), + |_| true, + |_| true, + ); + assert_eq!( + extract_events_payload(result), + vec![3, 11, 6, 7, 13, 9, 14, 15, 16] + ) + } + + #[test] + fn backfill_exclude_events() { + let events = mock_full_graph(); + let roots = vec![mock_event_id(&1)]; + let excluded_events = vec![ + mock_event_id(&14), + mock_event_id(&15), + mock_event_id(&16), + mock_event_id(&3), + ]; + let result = linearize_previous_events( + roots, + excluded_events, + 100u64, + |e| events.get(e).cloned(), + |_| true, + |_| true, + ); + assert_eq!( + extract_events_payload(result), + vec![1, 2, 10, 11, 4, 5, 6, 7, 12, 13, 8, 9] + ) + } + + #[test] + fn backfill_exclude_branch_with_evil_event() { + let events = mock_full_graph(); + let roots = vec![mock_event_id(&1)]; + let result = linearize_previous_events( + roots, + vec![], + 100u64, + |e| events.get(e).cloned(), + |_| true, + |e| { + let value: Value = CanonicalJsonValue::Object(e.clone()).into(); + let pdu: MockPDU = serde_json::from_value(value).unwrap(); + pdu.content != 3 + }, + ); + assert_eq!( + extract_events_payload(result), + vec![1, 2, 10, 11, 4, 5, 12, 13, 8, 14, 15, 16] + ) + } + + #[test] + fn backfill_exclude_branch_with_inaccessible_event() { + let events = mock_full_graph(); + let roots = vec![mock_event_id(&1)]; + let result = linearize_previous_events( + roots, + vec![], + 100u64, + |e| events.get(e).cloned(), + |e| e != mock_event_id(&3), + |_| true, + ); + assert_eq!( + extract_events_payload(result), + vec![1, 2, 10, 11, 4, 5, 12, 13, 8, 14, 15, 16] + ) + } }