Draft: SSO login (OAuth 2.0 + OpenID Connect) #1012

Open
avdb13 wants to merge 11 commits from oidc into next
11 changed files with 402 additions and 7 deletions
Showing only changes of commit bc7cf1955b - Show all commits

View file

@ -30,10 +30,12 @@ workspace = true
[dependencies]
# Web framework
axum = { version = "0.6.18", default-features = false, features = ["form", "headers", "http1", "http2", "json", "matched-path"], optional = true }
axum = { version = "0.6.18", default-features = false, features = ["form", "headers", "http1", "http2", "json", "matched-path", "query"], optional = true }
axum-extra = { version = "0.8.0", features = ["cookie"] }
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
tower = { version = "0.4.13", features = ["util"] }
tower-http = { version = "0.4.1", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] }
chrono = "0.4"
# Used for matrix spec type definitions and helpers
#ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] }
@ -53,6 +55,8 @@ bytes = "1.4.0"
http = "0.2.9"
# Used to find data directory for default db path
directories = "4.0.1"
macaroon = { git = "https://github.com/macaroon-rs/macaroon.git", branch = "main" }
openid = "0.9"
# Used for ruma wrapper
serde_json = { version = "1.0.96", features = ["raw_value"] }
# Used for appservice registration files
@ -108,6 +112,8 @@ clap = { version = "4.3.0", default-features = false, features = ["std", "derive
futures-util = { version = "0.3.28", default-features = false }
# Used for reading the configuration from conduit.toml & environment variables
figment = { version = "0.10.8", features = ["env", "toml"] }
uuid = { version = "0.8", features = ["serde", "v4"] }
time = "0.3.22"
tikv-jemalloc-ctl = { version = "0.5.0", features = ["use_std"], optional = true }
tikv-jemallocator = { version = "0.5.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true }

View file

@ -55,3 +55,10 @@ trusted_servers = ["matrix.org"]
address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy
#address = "0.0.0.0" # If Conduit is running in a container, make sure the reverse proxy (ie. Traefik) can reach it.
[default.openid]
client_id = "conduit"
secret = "00000000-0000-0000-0000-000000000000"
discover_url = "https://keycloak.domain.com/auth/realms/Realm_name"
macaroon_key = "this is the key"
redirect_url = "http://localhost:8081/sso_return"

View file

@ -253,6 +253,10 @@
# Needed for our script for Complement
jq
# Needed for SSO
pkgs.openssl
pkgs.pkg-config
]);
};
});

View file

@ -22,6 +22,7 @@ mod room;
mod search;
mod session;
mod space;
mod sso;
mod state;
mod sync;
mod tag;
@ -57,6 +58,7 @@ pub use room::*;
pub use search::*;
pub use session::*;
pub use space::*;
pub use sso::*;
pub use state::*;
pub use sync::*;
pub use tag::*;

View file

@ -1,15 +1,22 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma};
use base64::alphabet;
use base64::engine;
use base64::engine::general_purpose;
// use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use macaroon::Verifier;
use ruma::{
api::client::{
error::ErrorKind,
session::{get_login_types, login, logout, logout_all},
uiaa::UserIdentifier,
},
UserId,
events::GlobalAccountDataEventType,
push, UserId,
};
use serde::Deserialize;
use tracing::{info, warn};
use tracing::{debug, error, info, warn};
#[derive(Debug, Deserialize)]
struct Claims {
@ -17,6 +24,56 @@ struct Claims {
//exp: usize,
}
#[tracing::instrument]
fn verifier_callback(v: &macaroon::ByteString) -> bool {
use std::num::ParseIntError;
let result: Result<bool, String> = (|| {
if v.0.starts_with(b"time < ") {
let v2 = std::str::from_utf8(&v.0).map_err(|e| e.to_string())?;
let v3 = v2.trim_start_matches("time < ");
let v4: i64 = v3.parse().map_err(|e: ParseIntError| e.to_string())?;
let now = chrono::Utc::now().timestamp();
if now < v4 {
debug!("macaroon is not expired yet");
Ok(true)
} else {
debug!(
"macaroon expired, v4={} , now={}, v4-now={}",
v4,
now,
v4 - now
);
Ok(false)
}
} else {
Ok(false)
}
})();
match result {
Ok(r) => r,
Err(e) => {
error!("verifier_callback: {:?}", e);
false
}
}
}
#[test]
fn test_verifier_callback() {
use macaroon::ByteString;
let now = chrono::Utc::now().timestamp();
assert!(verifier_callback(&ByteString(
format!("time < {}", now + 10).as_bytes().to_vec()
)));
assert!(!verifier_callback(&ByteString(
format!("time < {}", now - 10).as_bytes().to_vec()
)));
}
/// # `GET /_matrix/client/r0/login`
///
/// Get the supported login types of this server. One of these should be used as the `type` field
@ -27,6 +84,7 @@ pub async fn get_login_types_route(
Ok(get_login_types::v3::Response::new(vec![
get_login_types::v3::LoginType::Password(Default::default()),
get_login_types::v3::LoginType::ApplicationService(Default::default()),
get_login_types::v3::LoginType::Sso(get_login_types::v3::SsoLoginType::default()),
]))
}
@ -94,6 +152,9 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
user_id
}
login::v3::LoginInfo::Token(login::v3::Token { token }) => {
const CUSTOM_ENGINE: engine::GeneralPurpose =
engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD);
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
let token = jsonwebtoken::decode::<Claims>(
token,
@ -105,6 +166,57 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."),
)?
} else if macaroon::Macaroon::deserialize(&CUSTOM_ENGINE.decode(token).unwrap()) // TODO
.is_ok()
{
println!("TOKEN! {}", token);
let macaroon =
macaroon::Macaroon::deserialize(&CUSTOM_ENGINE.decode(token).unwrap()).unwrap();
let v1 = macaroon.identifier();
let user_id = std::str::from_utf8(&v1.0).unwrap();
println!("identifier: {}", user_id);
println!("location: {:?}", macaroon.location());
println!("sig: {:?}", macaroon.signature());
let mut verifier = Verifier::default();
verifier.satisfy_general(verifier_callback);
let openid_client = &services().globals.openid_client;
let (key, _client) = openid_client.as_ref().unwrap();
match verifier.verify(&macaroon, &key, Default::default()) {
Ok(()) => println!("Macaroon verified!"),
Err(error) => println!("Error validating macaroon: {:?}", error),
}
let user_id =
UserId::parse_with_server_name(user_id, services().globals.server_name())
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?;
println!("user_id: {}", user_id);
if !services().users.exists(&user_id)? {
let random_password = crate::utils::random_string(TOKEN_LENGTH);
services().users.create(&user_id, Some(&random_password))?;
services().account_data.update(
None,
&user_id,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
content: ruma::events::push_rules::PushRulesEventContent {
global: push::Ruleset::server_default(&user_id),
},
})
.expect("to json always works"),
)?;
}
user_id
} else {
return Err(Error::BadRequest(
ErrorKind::Unknown,

View file

@ -0,0 +1,222 @@
use axum::extract::Query;
use axum::response::IntoResponse;
use axum::Error;
use axum_extra::extract::cookie::{Cookie, SameSite};
use axum_extra::extract::CookieJar;
use macaroon::Macaroon;
use openid::{Token, Userinfo};
use rand::{thread_rng, Rng};
use reqwest::Url;
use serde::{Deserialize, Serialize};
use crate::{services, Result};
const COOKIE_STATE_EXPIRATION_SECS: i64 = 10 * 60;
const MAC_VALID_SECS: i64 = 10;
const PROOF_KEY_LEN: usize = 32;
#[derive(Deserialize, Serialize)]
struct State {
after_auth: String,
proof_key: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SsoRedirectParams {
pub redirect_url: String,
}
pub async fn get_sso_redirect(
Query(params): Query<SsoRedirectParams>,
cookies: CookieJar,
) -> Result<impl IntoResponse> {
let SsoRedirectParams { redirect_url } = params;
let openid_client = &services().globals.openid_client;
let (_key, client) = openid_client.as_ref().unwrap();
use base64::{
alphabet,
engine::{self, general_purpose},
Engine as _,
};
const CUSTOM_ENGINE: engine::GeneralPurpose =
engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD);
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
let mut arr = [0u8; PROOF_KEY_LEN];
thread_rng().fill(&mut arr[..]);
let proof_key = CUSTOM_ENGINE.encode(arr);
let state = State {
after_auth: redirect_url.to_string(),
proof_key,
};
let state = serde_json::to_string(&state).unwrap();
let state_b64 = CUSTOM_ENGINE.encode(state.as_bytes());
let state_b64_sha256 = ring::digest::digest(&ring::digest::SHA256, &state_b64.as_bytes());
let state_b64_sha256_b64 = CUSTOM_ENGINE.encode(state_b64_sha256);
let cookie1 = Cookie::build("openid-state", state_b64)
.path("/sso_return")
.secure(false) //FIXME
.http_only(true)
.same_site(SameSite::None)
.max_age(time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS))
.finish();
let updated_jar = cookies.add(cookie1);
// https://docs.rs/openid/0.4.0/openid/struct.Options.html
let auth_url = client.auth_url(&openid::Options {
scope: Some("email".into()), // TODO: openid only?
//TODO: nonce?
state: Some(state_b64_sha256_b64.to_string()),
..Default::default()
});
let redirect = axum::response::Redirect::to(&auth_url.to_string());
Ok((updated_jar, redirect))
}
async fn request_token(
oidc_client: &openid::DiscoveredClient,
code: &str,
) -> Result<Option<(Token, Userinfo)>, Error> {
let mut token: Token = oidc_client.request_token(&code).await.unwrap().into();
if let Some(mut id_token) = token.id_token.as_mut() {
oidc_client.decode_token(&mut id_token).unwrap();
oidc_client.validate_token(&id_token, None, None).unwrap();
// eprintln!("token: {:?}", id_token);
} else {
return Ok(None);
}
let userinfo = oidc_client.request_userinfo(&token).await.unwrap();
// eprintln!("user info: {:?}", userinfo);
Ok(Some((token, userinfo)))
}
// #[derive(Debug)]
// struct User {
// id: String,
// login: Option<String>,
// first_name: Option<String>,
// last_name: Option<String>,
// email: Option<String>,
// image_url: Option<String>,
// activated: bool,
// lang_key: Option<String>,
// authorities: Vec<String>,
// }
// #[derive(Debug, Responder)]
// pub enum ExampleResponse<'a> {
// Redirect(Redirect),
// Unauthorized(rocket::response::status::Unauthorized<&'a str>),
// }
#[derive(Debug, Deserialize)]
pub struct Params {
// pub session_state: String,
pub state: String,
pub code: String,
}
pub async fn get_sso_return(
Query(params): Query<Params>,
cookies: CookieJar,
) -> Result<impl IntoResponse> {
let Params { code, state, .. } = params;
use base64::{
alphabet,
engine::{self, general_purpose},
Engine as _,
};
const CUSTOM_ENGINE: engine::GeneralPurpose =
engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD);
let state = CUSTOM_ENGINE.decode(state).unwrap();
// TODO: test with expired/deleted cookie
let cookie_state = cookies.get("openid-state").unwrap();
let cookie_state_b64_sha256 =
ring::digest::digest(&ring::digest::SHA256, &cookie_state.value().as_bytes());
if state != cookie_state_b64_sha256.as_ref() {
// return ExampleResponse::Unauthorized(rocket::response::status::Unauthorized(Some(
// "invalid state",
// )));
panic!("invalid state");
}
let decoded_state = CUSTOM_ENGINE.decode(cookie_state.value()).unwrap();
let decoded_state: State = serde_json::from_slice(&decoded_state).unwrap();
let openid_client = &services().globals.openid_client;
let (key, client) = openid_client.as_ref().unwrap();
let username;
match request_token(client, &code).await {
Ok(Some((_token, userinfo))) => {
/*
let id = uuid::Uuid::new_v4().to_string();
let login = userinfo.preferred_username.clone();
let email = userinfo.email.clone();
h
let new_user = User {
id: userinfo.sub.clone().unwrap_or_default(),
login,
last_name: userinfo.family_name.clone(),
first_name: userinfo.name.clone(),
email,
activated: userinfo.email_verified,
image_url: userinfo.picture.clone().map(|x| x.to_string()),
lang_key: Some("en".to_string()),
authorities: vec!["ROLE_USER".to_string()], //FIXME: read from token
};
*/
// user = new_user.login.unwrap();
username = userinfo.preferred_username.unwrap();
}
Ok(None) => {
// return ExampleResponse::Unauthorized(rocket::response::status::Unauthorized(Some(
// "no id_token found",
// )));
panic!("no id_token found");
}
Err(err) => {
eprintln!("login error in call: {:?}", err);
// return ExampleResponse::Unauthorized(rocket::response::status::Unauthorized(Some(
// "login error in call",
// )));
panic!("login error in call");
}
}
// Create our macaroon
let mut macaroon = match Macaroon::create(Some("location".into()), &key, username.into()) {
Ok(macaroon) => macaroon,
Err(error) => panic!("Error creating macaroon: {:?}", error),
};
let something = format!("time < {}", chrono::Utc::now().timestamp() + MAC_VALID_SECS).into();
macaroon.add_first_party_caveat(something);
let serialized = macaroon.serialize(macaroon::Format::V2).unwrap();
let encoded = CUSTOM_ENGINE.encode(serialized);
let redirect_url =
Url::parse_with_params(&decoded_state.after_auth, &[("loginToken", encoded)]).unwrap();
let redirect = axum::response::Redirect::to(&redirect_url.to_string());
Ok(redirect)
}

View file

@ -4,6 +4,7 @@ use std::{
net::{IpAddr, Ipv4Addr},
};
use reqwest::Url;
use ruma::{OwnedServerName, RoomVersionId};
use serde::{de::IgnoredAny, Deserialize};
use tracing::warn;
@ -79,6 +80,7 @@ pub struct Config {
pub turn_secret: String,
#[serde(default = "default_turn_ttl")]
pub turn_ttl: u64,
pub openid: Option<OpenIdConfig>,
pub emergency_password: Option<String>,
@ -92,6 +94,15 @@ pub struct TlsConfig {
pub key: String,
}
#[derive(Clone, Debug, Deserialize)]
pub struct OpenIdConfig {
pub client_id: String,
pub secret: String,
pub discover_url: Url,
pub macaroon_key: String,
pub redirect_url: String,
}
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
impl Config {

View file

@ -401,7 +401,7 @@ impl KeyValueDatabase {
let db = Box::leak(db_raw);
let services_raw = Box::new(Services::build(db, config)?);
let services_raw = Box::new(Services::build(db, config).await?);
// This is the first and only time we initialize the SERVICE static
*SERVICES.write().unwrap() = Some(Box::leak(services_raw));

View file

@ -148,6 +148,7 @@ async fn run_server() -> io::Result<()> {
let x_requested_with = HeaderName::from_static("x-requested-with");
let middlewares = ServiceBuilder::new()
// TODO token
.sensitive_headers([header::AUTHORIZATION])
.layer(axum::middleware::from_fn(spawn_task))
.layer(
@ -399,6 +400,11 @@ fn routes() -> Router {
"/_matrix/key/v2/server/:key_id",
get(server_server::get_server_keys_deprecated_route),
)
.route(
"/_matrix/client/v3/login/sso/redirect",
get(client_server::get_sso_redirect),
)
.route("/sso_return", get(client_server::get_sso_return))
.ruma_route(server_server::get_public_rooms_route)
.ruma_route(server_server::get_public_rooms_filtered_route)
.ruma_route(server_server::send_transaction_message_route)

View file

@ -75,6 +75,8 @@ pub struct Service {
pub rotate: RotationHandler,
pub shutdown: AtomicBool,
pub openid_client: Option<(macaroon::MacaroonKey, openid::DiscoveredClient)>,
}
/// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
@ -147,7 +149,7 @@ impl Resolve for Resolver {
}
impl Service {
pub fn load(db: &'static dyn Data, config: Config) -> Result<Self> {
pub async fn load(db: &'static dyn Data, config: Config) -> Result<Self> {
let keypair = db.load_keypair();
let keypair = match keypair {
@ -182,6 +184,28 @@ impl Service {
// Experimental, partially supported room versions
let unstable_room_versions = vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5];
let openid_client = match config.openid.as_ref() {
Some(openid) => {
let mut key_bytes: [u8; 32] = [0; 32];
key_bytes.copy_from_slice(&base64::decode(&openid.macaroon_key).unwrap());
let secret_key: macaroon::MacaroonKey = key_bytes.into();
let r = (
secret_key,
openid::DiscoveredClient::discover(
openid.client_id.to_owned(),
openid.secret.to_owned(),
Some(openid.redirect_url.to_owned()),
openid.discover_url.to_owned(),
)
.await
.unwrap(),
);
Some(r)
}
None => None,
};
let mut s = Self {
db,
config,
@ -212,6 +236,7 @@ impl Service {
sync_receivers: RwLock::new(HashMap::new()),
rotate: RotationHandler::new(),
shutdown: AtomicBool::new(false),
openid_client,
};
fs::create_dir_all(s.get_media_folder())?;

View file

@ -37,7 +37,7 @@ pub struct Services {
}
impl Services {
pub fn build<
pub async fn build<
D: appservice::Data
+ pusher::Data
+ rooms::Data
@ -115,7 +115,7 @@ impl Services {
media: media::Service { db },
sending: sending::Service::build(db, &config),
globals: globals::Service::load(db, config)?,
globals: globals::Service::load(db, config).await?,
})
}
fn memory_usage(&self) -> String {