From 8fd6a88f43bc1d9f0070457b287570ce0c45c16f Mon Sep 17 00:00:00 2001 From: mikoto Date: Thu, 22 Feb 2024 18:12:03 +0000 Subject: [PATCH] initial draft --- Cargo.toml | 4 +- askama.toml | 2 + conduit-example.toml | 20 +- public/conduit.svg | 6 + public/templates/base.html | 152 ++++++++++++ public/templates/sso_footer.html | 19 ++ public/templates/sso_login_idp_picker.html | 58 +++++ public/templates/terms.html | 27 +++ src/api/client_server/mod.rs | 4 +- src/api/client_server/oidc.rs | 114 +++++++++ src/api/client_server/session.rs | 28 ++- src/api/client_server/sso.rs | 254 --------------------- src/config/mod.rs | 2 + src/config/oidc.rs | 142 +++++++++--- src/main.rs | 12 +- src/service/globals/mod.rs | 31 --- src/service/mod.rs | 3 + src/service/oidc/mod.rs | 252 ++++++++++++++++++++ src/utils/mod.rs | 29 ++- 19 files changed, 809 insertions(+), 350 deletions(-) create mode 100644 askama.toml create mode 100644 public/conduit.svg create mode 100644 public/templates/base.html create mode 100644 public/templates/sso_footer.html create mode 100644 public/templates/sso_login_idp_picker.html create mode 100644 public/templates/terms.html create mode 100644 src/api/client_server/oidc.rs delete mode 100644 src/api/client_server/sso.rs create mode 100644 src/service/oidc/mod.rs diff --git a/Cargo.toml b/Cargo.toml index e5d609b3..29db8795 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,7 +56,6 @@ http = "0.2.9" # Used to find data directory for default db path directories = "4.0.1" macaroon = { git = "https://github.com/macaroon-rs/macaroon.git", branch = "main" } -openid = "0.9" # Used for ruma wrapper serde_json = { version = "1.0.96", features = ["raw_value"] } # Used for appservice registration files @@ -122,6 +121,9 @@ async-trait = "0.1.68" sd-notify = { version = "0.4.1", optional = true } url = { version = "2.5.0", features = ["serde"] } +openidconnect = { version = "3.5.0", features = ["jwk-alg", "accept-string-booleans"] } +askama = { version = "0.12.1", features = ["with-axum"] } +askama_axum = { version = "0.4.0", features = ["urlencode", "config"] } [target.'cfg(unix)'.dependencies] nix = { version = "0.26.2", features = ["resource"] } diff --git a/askama.toml b/askama.toml new file mode 100644 index 00000000..893df5be --- /dev/null +++ b/askama.toml @@ -0,0 +1,2 @@ +[general] +dirs = ["public/templates"] diff --git a/conduit-example.toml b/conduit-example.toml index 86827c87..40cacd06 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -58,11 +58,15 @@ address = "127.0.0.1" # This makes sure Conduit can only be reached using the re macaroon_key = "this is the key" -[[oidc_provider]] -id = "keycloak" -name = "keycloak" -client.id = "conduit" -client.secret = "00000000-0000-0000-0000-000000000000" -discover_url = "https://keycloak.domain.com/auth/realms/example" -scopes = ["openid", "read_user"] -backchannel_logout = true +[[global.oidc]] +idp_id = "gitlab" +idp_name = "Gitlab" +idp_icon = "mxc://matrix.org/00000000000000000000000000000000" + +issuer = "https://gitlab.com" +scopes = ["openid", "profile"] + +[global.oidc.client] +id = "0000000000000000000000000000000000000000000000000000000000000000" +secret = "0000000000000000000000000000000000000000000000000000000000000000" +auth_method = "post" diff --git a/public/conduit.svg b/public/conduit.svg new file mode 100644 index 00000000..aa7a352c --- /dev/null +++ b/public/conduit.svg @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/public/templates/base.html b/public/templates/base.html new file mode 100644 index 00000000..e08cb922 --- /dev/null +++ b/public/templates/base.html @@ -0,0 +1,152 @@ + + + + + + + {% block title %}{% endblock %} + + + + +
+ [matrix] +
+ +{% block body %}{% endblock %} + + + diff --git a/public/templates/sso_footer.html b/public/templates/sso_footer.html new file mode 100644 index 00000000..fdcb206c --- /dev/null +++ b/public/templates/sso_footer.html @@ -0,0 +1,19 @@ + diff --git a/public/templates/sso_login_idp_picker.html b/public/templates/sso_login_idp_picker.html new file mode 100644 index 00000000..8bab0bb3 --- /dev/null +++ b/public/templates/sso_login_idp_picker.html @@ -0,0 +1,58 @@ +{% extends "base.html" %} +{% block title %}Choose identity provider{% endblock %} + +{% block style %} + .providers { + list-style: none; + padding: 0; + } + + .providers li { + margin: 12px; + } + + .providers a { + display: block; + border-radius: 4px; + border: 1px solid #17191C; + padding: 8px; + text-align: center; + text-decoration: none; + color: #17191C; + display: flex; + align-items: center; + font-weight: bold; + } + + .providers a img { + width: 24px; + height: 24px; + } + .providers a span { + flex: 1; + } +{% endblock %} + +{% block body %} +
+

Log in to {{ server_name }}

+

Choose an identity provider to log in

+
+
+ +
+{% include "sso_footer.html" %} +{% endblock %} diff --git a/public/templates/terms.html b/public/templates/terms.html new file mode 100644 index 00000000..66c40a70 --- /dev/null +++ b/public/templates/terms.html @@ -0,0 +1,27 @@ +{% extends "_base.html" %} +{% block title %}Authentication{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} +
+
+ {% if error is defined %} +

Error: {{ error }}

+ {% endif %} +

+ Please click the button below if you agree to the + privacy policy of this homeserver. +

+ + +
+
+{% endblock %} diff --git a/src/api/client_server/mod.rs b/src/api/client_server/mod.rs index f244de83..479c760c 100644 --- a/src/api/client_server/mod.rs +++ b/src/api/client_server/mod.rs @@ -11,6 +11,7 @@ mod keys; mod media; mod membership; mod message; +mod oidc; mod presence; mod profile; mod push; @@ -22,7 +23,6 @@ mod room; mod search; mod session; mod space; -mod sso; mod state; mod sync; mod tag; @@ -47,6 +47,7 @@ pub use keys::*; pub use media::*; pub use membership::*; pub use message::*; +pub use oidc::*; pub use presence::*; pub use profile::*; pub use push::*; @@ -58,7 +59,6 @@ pub use room::*; pub use search::*; pub use session::*; pub use space::*; -pub use sso::*; pub use state::*; pub use sync::*; pub use tag::*; diff --git a/src/api/client_server/oidc.rs b/src/api/client_server/oidc.rs new file mode 100644 index 00000000..fc61eec2 --- /dev/null +++ b/src/api/client_server/oidc.rs @@ -0,0 +1,114 @@ +use crate::{ + config::Metadata, service::oidc::COOKIE_STATE_EXPIRATION_SECS, services, Error, Result, Ruma, +}; +use askama::Template; +use axum::response::IntoResponse; +use axum_extra::extract::cookie::{Cookie, SameSite}; +use bytes::BufMut; +use http::{header::COOKIE, HeaderValue, StatusCode}; +use ruma::api::{ + client::{error::ErrorKind, session}, + error::IntoHttpError, + OutgoingResponse, +}; + +// const SEED_LEN: usize = 32; + +/// # `GET /_matrix/client/v3/login/sso/redirect` +/// +/// Redirect user to SSO interface. +/// +pub async fn get_sso_redirect( + body: Ruma, +) -> axum::response::Response { + let server_name = services().globals.server_name().to_string(); + let metadata = services().oidc.get_metadata(); + let redirect_url = body.redirect_url.clone(); + + let t = SsoTemplate { + server_name, + metadata, + redirect_url, + }; + + match t.render() { + Ok(body) => { + let headers = [( + http::header::CONTENT_TYPE, + http::HeaderValue::from_static(SsoTemplate::MIME_TYPE), + )]; + (headers, body).into_response() + } + Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "woops").into_response(), + } +} + +pub struct SsoResponse { + pub inner: session::sso_login_with_provider::v3::Response, + pub cookie: String, +} + +impl OutgoingResponse for SsoResponse { + fn try_into_http_response( + self, + ) -> Result, IntoHttpError> { + self.inner.try_into_http_response().map(|mut ok| { + *ok.status_mut() = StatusCode::FOUND; + + match HeaderValue::from_str(self.cookie.as_str()) { + Ok(value) => { + ok.headers_mut().insert(COOKIE, value); + + Ok(ok) + } + Err(e) => Err(IntoHttpError::Header(e)), + } + })? + } +} + +/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}` +/// +/// Redirect user to SSO interface. +pub async fn get_sso_redirect_with_idp_id( + body: Ruma, + // State(uiaa_session): State>, +) -> Result { + // if services().oidc.get_all().len() == 1 { + // } + + let Ok(provider) = services().oidc.get_provider(&body.idp_id).await else { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Unknown identity provider", + )); + }; + + let (location, cookie) = provider.handle_redirect(&body.redirect_url).await; + let inner = session::sso_login_with_provider::v3::Response { + location: location.to_string(), + }; + + let cookie = Cookie::build("openid-state", cookie) + .path("/_conduit/client/oidc") + .secure(false) //FIXME + // .secure(true) + .http_only(true) + .same_site(SameSite::None) + .max_age(time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS)) + .finish() + .to_string(); + + Ok(SsoResponse { inner, cookie }) + // Ok((axum::http::StatusCode::FOUND, [(LOCATION, &body.redirect_url)]).into_response()) +} + +pub async fn get_sso_return() {} + +#[derive(Template)] +#[template(path = "sso_login_idp_picker.html", escape = "none")] +struct SsoTemplate { + pub server_name: String, + pub metadata: Vec, + pub redirect_url: String, +} diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 028d6ce0..c369a62c 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,8 +1,6 @@ use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use crate::{services, utils, Error, Result, Ruma}; -use base64::alphabet; -use base64::engine; -use base64::engine::general_purpose; +use base64::{alphabet, engine, engine::general_purpose}; // use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; use macaroon::Verifier; @@ -81,10 +79,20 @@ fn test_verifier_callback() { pub async fn get_login_types_route( _body: Ruma, ) -> Result { + let identity_providers = services() + .oidc + .get_metadata() + .clone() + .into_iter() + .map(Into::into) + .collect(); + Ok(get_login_types::v3::Response::new(vec![ get_login_types::v3::LoginType::Password(Default::default()), get_login_types::v3::LoginType::ApplicationService(Default::default()), - get_login_types::v3::LoginType::Sso(get_login_types::v3::SsoLoginType::default()), + get_login_types::v3::LoginType::Sso(get_login_types::v3::SsoLoginType { + identity_providers, + }), ])) } @@ -184,13 +192,13 @@ pub async fn login_route(body: Ruma) -> Result println!("Macaroon verified!"), - Err(error) => println!("Error validating macaroon: {:?}", error), - } + // match verifier.verify(&macaroon, &key, Default::default()) { + // Ok(()) => println!("Macaroon verified!"), + // Err(error) => println!("Error validating macaroon: {:?}", error), + // } let user_id = UserId::parse_with_server_name(user_id, services().globals.server_name()) diff --git a/src/api/client_server/sso.rs b/src/api/client_server/sso.rs deleted file mode 100644 index 70d1069e..00000000 --- a/src/api/client_server/sso.rs +++ /dev/null @@ -1,254 +0,0 @@ -use axum::extract::Query; -use axum::response::IntoResponse; -use axum_extra::extract::cookie::{Cookie, SameSite}; -use axum_extra::extract::CookieJar; -use macaroon::Macaroon; -use openid::{Token, Userinfo, Provider}; -use rand::{thread_rng, Rng}; -use reqwest::Url; -use ring::digest; -use ruma::api::client::error::ErrorKind; -use serde::{Deserialize, Serialize}; - -use crate::{services, Result, Error}; - -const COOKIE_STATE_EXPIRATION_SECS: i64 = 10 * 60; -const MAC_VALID_SECS: i64 = 10; -const PROOF_KEY_LEN: usize = 32; -const NONCE_LEN: usize = 32; - -#[derive(Deserialize, Serialize)] -struct State { - after_auth: String, - proof_key: String, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct SsoRedirectParams { - pub redirect_url: String, -} - -/// # `GET /_matrix/client/v3/login/sso/redirect` -/// -/// Redirect user to SSO interface. -/// -pub async fn get_sso_redirect( - Query(params): Query, - // State(uia_session): State>, - cookies: CookieJar, -) -> Result { - let SsoRedirectParams { redirect_url } = params; - - let client = services().globals.oidc.as_ref().unwrap(); - let key = services().globals.macaroon.as_ref() -.ok_or(Error::BadConfig(&"Missing macaroon key in config file."))?; - - use base64::{ - alphabet, - engine::{self, general_purpose}, - Engine as _, - }; - - const CUSTOM_ENGINE: engine::GeneralPurpose = - engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD); - - // https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 - let mut arr = [0u8; PROOF_KEY_LEN]; - - thread_rng().fill(&mut arr[..]); - let proof_key = CUSTOM_ENGINE.encode(arr); - - thread_rng().fill(&mut arr[..]); - let nonce = CUSTOM_ENGINE.encode(arr); - - let state = State { - after_auth: redirect_url.to_string(), - proof_key, - }; - - let state = serde_json::to_string(&state).unwrap(); - - let key = macaroon::MacaroonKey::generate(key.as_ref()); - let mut macaroon = macaroon::Macaroon::create(None, &key, "key".into()).unwrap(); - let issuer = client.provider.auth_uri(); - - let expires = chrono::Utc::now() + chrono::TimeDelta::seconds(COOKIE_STATE_EXPIRATION_SECS); - - macaroon.add_first_party_caveat(format!("state = {state}").into()); - macaroon.add_first_party_caveat(format!("provider = ???", ).into()); - macaroon.add_first_party_caveat(format!("state = {state}").into()); - macaroon.add_first_party_caveat(format!("nonce = {nonce}").into()); - macaroon.add_first_party_caveat(format!("redirect_url = {redirect_url}").into()); - - let cookie1 = Cookie::build("openid-state", state_b64) - .path("/sso_return") - .secure(false) //FIXME - // .secure(true) - .http_only(true) - .same_site(SameSite::None) - .max_age(time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS)) - .finish(); - let updated_jar = cookies.add(cookie1); - - let cookie2 = Cookie::build("openid-state-no-samesite", state_b64) - .path("/sso_return") - .http_only(true) - .max_age(time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS)) - .finish(); - let updated_jar = cookies.add(cookie1).add(cookie2); - - - let auth_url = client.auth_url(&openid::Options { - scope: Some("email".into()), - state: Some(state_b64_sha256_b64.to_string()), - ..Default::default() - }); - - let redirect = axum::response::Redirect::to(auth_url.as_ref()); - Ok((updated_jar, redirect)) -} - -async fn request_token( - oidc_client: &openid::DiscoveredClient, - code: &str, -) -> Result<(Token, Userinfo), Error> { - let mut token: Token = oidc_client.request_token(code).await - .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "OICD token request failed."))? - .into(); - - let Some(ref mut id_token) = token.id_token else { - return Err(Error::BadServerResponse("OICD token did not contain id_token"))?; - }; - - oidc_client.decode_token(id_token) - .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Couldn't decode token."))?; - oidc_client.validate_token(id_token, None, None) - .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Couldn't validate token."))?; - - let userinfo = oidc_client.request_userinfo(&token).await - .map_err(|_| Error::BadServerResponse("Requesting userinfo failed."))?; - - Ok((token, userinfo)) -} - -// #[derive(Debug)] -// struct User { -// id: String, -// login: Option() -// first_name: Option, -// last_name: Option, -// email: Option, -// image_url: Option, -// activated: bool, -// lang_key: Option, -// authorities: Vec, -// } - -// #[derive(Debug, Responder)] -// pub enum ExampleResponse<'a> { -// Redirect(Redirect), -// Unauthorized(rocket::response::status::Unauthorized<&'a str>), -// } - -#[derive(Debug, Deserialize)] -pub struct Params { - // pub session_state: String, - pub state: String, - pub code: String, -} - -pub async fn get_sso_return( - Query(params): Query, - cookies: CookieJar, -) -> Result { - let Params { code, state, .. } = params; - - use base64::{ - alphabet, - engine::{self, general_purpose}, - Engine as _, - }; - - const CUSTOM_ENGINE: engine::GeneralPurpose = - engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD); - - let state = CUSTOM_ENGINE.decode(state).unwrap(); - - // TODO: test with expired/deleted cookie - let cookie_state = cookies.get("openid-state").unwrap(); - let cookie_state_b64_sha256 = - digest::digest(&digest::SHA256, &cookie_state.value().as_bytes()); - - if state != cookie_state_b64_sha256.as_ref() { - // return ExampleResponse::Unauthorized(rocket::response::status::Unauthorized(Some( - // "invalid state", - // ))); - panic!("invalid state"); - } - - let decoded_state = CUSTOM_ENGINE.decode(cookie_state.value()).unwrap(); - let decoded_state: State = serde_json::from_slice(&decoded_state).unwrap(); - - let openid_client = &services().globals.openid_client; - let (key, client) = openid_client.as_ref().unwrap(); - - let username; - match request_token(client, &code).await { - Ok(Some((_token, userinfo))) => { - /* - let id = uuid::Uuid::new_v4().to_string(); - - let login = userinfo.preferred_username.clone(); - let email = userinfo.email.clone(); - h - let new_user = User { - id: userinfo.sub.clone().unwrap_or_default(), - login, - last_name: userinfo.family_name.clone(), - first_name: userinfo.name.clone(), - email, - activated: userinfo.email_verified, - image_url: userinfo.picture.clone().map(|x| x.to_string()), - lang_key: Some("en".to_string()), - authorities: vec!["ROLE_USER".to_string()], //FIXME: read from token - }; - */ - - // user = new_user.login.unwrap(); - username = userinfo.preferred_username.unwrap(); - } - Ok(None) => { - // return ExampleResponse::Unauthorized(rocket::response::status::Unauthorized(Some( - // "no id_token found", - // ))); - panic!("no id_token found"); - } - Err(err) => { - eprintln!("login error in call: {:?}", err); - // return ExampleResponse::Unauthorized(rocket::response::status::Unauthorized(Some( - // "login error in call", - // ))); - panic!("login error in call"); - } - } - - // Create our macaroon - let mut macaroon = match Macaroon::create(Some("location".into()), &key, username.into()) { - Ok(macaroon) => macaroon, - Err(error) => panic!("Error creating macaroon: {:?}", error), - }; - - let something = format!("time < {}", chrono::Utc::now().timestamp() + MAC_VALID_SECS).into(); - macaroon.add_first_party_caveat(something); - - let serialized = macaroon.serialize(macaroon::Format::V2).unwrap(); - let encoded = CUSTOM_ENGINE.encode(serialized); - - let redirect_url = - Url::parse_with_params(&decoded_state.after_auth, &[("loginToken", encoded)]).unwrap(); - - let redirect = axum::response::Redirect::to(&redirect_url.to_string()); - - Ok(redirect) -} diff --git a/src/config/mod.rs b/src/config/mod.rs index 64cc3165..1b92bf7a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -11,6 +11,8 @@ use tracing::warn; mod proxy; mod oidc; +pub use oidc::*; + use self::{oidc::OidcConfig, proxy::ProxyConfig}; #[derive(Clone, Debug, Deserialize)] diff --git a/src/config/oidc.rs b/src/config/oidc.rs index 1c6810e1..18baf25d 100644 --- a/src/config/oidc.rs +++ b/src/config/oidc.rs @@ -1,60 +1,79 @@ -use ruma::{serde::AsRefStr, OwnedMxcUri}; +use openidconnect::JsonWebKeyId; +use ruma::{ + api::client::session::get_login_types::v3::{IdentityProvider, IdentityProviderBrand}, + OwnedMxcUri, +}; use serde::Deserialize; pub type OidcConfig = Vec; #[derive(Clone, Debug, Deserialize)] -pub struct ProviderConfig { +pub struct Metadata { // Must be unique, used to distinguish OPs + #[serde(rename = "idp_id")] pub id: String, - pub name: Option, - pub icon: Option, - // Base URL of the OpenID Provider + #[serde(rename = "idp_name")] + pub name: Option, + + #[serde(rename = "idp_icon")] + pub icon: Option, +} + +impl Into for Metadata { + fn into(self) -> IdentityProvider { + let brand = match IdentityProviderBrand::from(self.id.clone()) { + IdentityProviderBrand::_Custom(_) => None, + brand => Some(brand), + }; + + IdentityProvider { + id: self.id.clone(), + name: self.name.unwrap_or(self.id), + icon: self.icon, + brand, + } + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct ProviderConfig { + // Information retrieved while creating the OpenID Application + pub client: ClientConfig, + + // Information for displaying the OpenID Provider + #[serde(flatten)] + pub metadata: Metadata, + + // Foo + // #[serde(deserialize_with = "crate::utils::deserialize_from_str")] pub issuer: url::Url, - // Always contains at least "openid" + + // Always contains "openid" by default // "profile", "email" and "name" are useful to suggest an MXID pub scopes: Vec, + // PKCE provides dynamic client secrets // Should be enabled when `ClientAuthMethod` is `None` pub pkce: Option, - // Allow existent accounts to login with OIDC - pub allow_existing_users: bool, - // Invalidate user sessions when the OP session expires - pub backchannel_logout: bool, - // Should be enabled when the authorization response does not contain userinfo - pub userinfo_override: bool, // Should be enabled when the authorization response does not contain a unique subject claim - subject_claim: Option, + pub subject_claim: Option, - pub client: ClientConfig, - pub metadata: MetadataConfig, -} + // Allow existent accounts to login with OIDC + #[serde(default)] + pub allow_existing_users: bool, -#[derive(Clone, Debug, Deserialize)] -pub enum MetadataConfig { - // Should be used for OPs supporting the OIDC Discovery endpoint - Discoverable, - Manual { - authorization: Option, - token: Option, - userinfo: Option, - jwk: Option, - }, -} + // Invalidate user sessions when the OP session expires + #[serde(default)] + pub backchannel_logout: bool, -#[derive(Clone, Debug, Deserialize, AsRefStr)] -pub enum ClientAuthMethod { - None, - // Provide the client combo in the Authorization header - Basic, - // Provide the client combo as in the POST request body - Post, - // Provide a JWT signed with client secret - SharedJwt, - // Provide a JWT signed with our own keypair (OP needs to know the public key) - PrivateJwt, + // Should be enabled when the authorization response does not contain userinfo + #[serde(default)] + pub userinfo_override: bool, + + #[serde(default)] + pub discovery: DiscoveryConfig, } #[derive(Clone, Debug, Deserialize)] @@ -63,4 +82,51 @@ pub struct ClientConfig { // Mandatory for the following `ClientAuthMethod`s: // [`Basic`,`Post`,`SharedJwt`] pub secret: Option, + + pub auth_method: AuthMethod, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Endpoints { + pub auth: url::Url, + pub token: Option, + pub userinfo: Option, + pub jwk: Option, +} + +#[derive(Clone, Debug, Default, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum DiscoveryConfig { + // Should be used for OPs supporting the OIDC Discovery endpoint + #[default] + Automatic, + Manual(Endpoints), +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AuthMethod { + None, + // Provide the client combo in the Authorization header + Basic, + // Provide the client combo as in the POST request body + Post, + // Provide a JWT signed with client secret + SharedJwt, + // Provide a JWT signed with a private key (OP needs to know the public key) + PrivateJwt, +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Algorithm { + Rsa, + EdDsa, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct PrivateSigningKey { + pub kind: Algorithm, + pub path: String, + pub kid: Option, } diff --git a/src/main.rs b/src/main.rs index 00c316e9..071fb1f4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -179,6 +179,7 @@ async fn run_server() -> io::Result<()> { header::CONTENT_TYPE, header::ACCEPT, header::AUTHORIZATION, + header::LOCATION, ]) .max_age(Duration::from_secs(86400)), ) @@ -400,11 +401,12 @@ fn routes() -> Router { "/_matrix/key/v2/server/:key_id", get(server_server::get_server_keys_deprecated_route), ) - .route( - "/_matrix/client/v3/login/sso/redirect", - get(client_server::get_sso_redirect), - ) - .route("/sso_return", get(client_server::get_sso_return)) + // .route( + // "/_matrix/client/v3/login/sso/redirect", + // get(client_server::get_sso_redirect), + // ) + .ruma_route(client_server::get_sso_redirect_with_idp_id) + // .route("/sso_return", get(client_server::get_sso_return)) .ruma_route(server_server::get_public_rooms_route) .ruma_route(server_server::get_public_rooms_filtered_route) .ruma_route(server_server::send_transaction_message_route) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index c8f28d38..07cd51f6 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -75,8 +75,6 @@ pub struct Service { pub rotate: RotationHandler, pub shutdown: AtomicBool, - - pub oidc: HashMap, pub macaroon: Option, } @@ -190,34 +188,6 @@ impl Service { .as_ref() .map(|s| macaroon::MacaroonKey::generate(s.as_bytes())); - let oidc = { - let discover_all = config.oidc.iter().map(|provider| { - openid::DiscoveredClient::discover_with_client( - default_client.clone(), - provider.client.id.clone(), - provider.client.secret.clone(), - Some(provider.redirect_url.to_string()), - provider.issuer.clone(), - ).map_ok(|client| (provider.id.clone(), client)) - }); - - let pairs = future::try_join_all(discover_all).await.map_err(|e| { - error!("failed to discover one or more OIDC providers: {}", e); - Error::bad_config("failed to discover one or more OIDC providers.") - })?; - - let mut result = HashMap::with_capacity(config.oidc.len()); - - for (id, client) in pairs { - let None = result.insert(id, client) else { - error!("OIDC providers must have unique IDs."); - return Err(Error::bad_config("OIDC providers must have unique IDs.")); - }; - } - - result - }; - let mut s = Self { db, config, @@ -249,7 +219,6 @@ impl Service { rotate: RotationHandler::new(), shutdown: AtomicBool::new(false), macaroon, - oidc, }; fs::create_dir_all(s.get_media_folder())?; diff --git a/src/service/mod.rs b/src/service/mod.rs index 504377c2..6895127c 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -13,6 +13,7 @@ pub mod appservice; pub mod globals; pub mod key_backups; pub mod media; +pub mod oidc; pub mod pdu; pub mod pusher; pub mod rooms; @@ -34,6 +35,7 @@ pub struct Services { pub key_backups: key_backups::Service, pub media: media::Service, pub sending: Arc, + pub oidc: Arc, } impl Services { @@ -115,6 +117,7 @@ impl Services { media: media::Service { db }, sending: sending::Service::build(db, &config), + oidc: oidc::Service::build(&config).await, globals: globals::Service::load(db, config).await?, }) } diff --git a/src/service/oidc/mod.rs b/src/service/oidc/mod.rs new file mode 100644 index 00000000..6a03d197 --- /dev/null +++ b/src/service/oidc/mod.rs @@ -0,0 +1,252 @@ +use std::sync::Arc; + +use macaroon::{Macaroon, MacaroonKey}; +use openidconnect::{ + core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata}, + reqwest::async_http_client, + AuthUrl, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, PkceCodeChallenge, RedirectUrl, + Scope, TokenUrl, UserInfoUrl, +}; +use tokio::sync::OnceCell; + +use crate::{ + config::{DiscoveryConfig as Discovery, Metadata, ProviderConfig}, + services, Config, Error, +}; + +pub const COOKIE_STATE_EXPIRATION_SECS: i64 = 10 * 60; + +#[derive(Clone)] +pub struct Client(Arc); + +#[derive(Clone)] +pub struct Provider { + pub config: ProviderConfig, + pub client: OnceCell, +} + +impl Provider { + pub fn new(config: ProviderConfig) -> Self { + Self { + config, + client: OnceCell::new(), + } + } + + pub async fn handle_redirect(&self, redirect_url: &str) -> (url::Url, String) { + let client = self + .client + .get_or_try_init(|| async { Client::new(self.config.clone()).await }) + .await + .map(|c| c.0.clone()) + .unwrap(); + let scopes = self + .config + .scopes + .iter() + .map(ToOwned::to_owned) + .map(Scope::new); + + let csrf = CsrfToken::new_random(); + let nonce = Nonce::new_random(); + let tmp = (csrf.clone(), nonce.clone()); + + let mut req = client + .authorize_url( + CoreAuthenticationFlow::Implicit(true), + move || csrf, + move || nonce, + ) + .add_scopes(scopes); + + let (challenge, verifier) = PkceCodeChallenge::new_random_sha256(); + if let Some(true) = self.config.pkce { + req = req.set_pkce_challenge(challenge); + } + + let (url, csrf, nonce) = req.url(); + + let cookie = self.generate_macaroon( + csrf.secret(), + nonce.secret(), + redirect_url, + self.config.pkce.map(|_| verifier.secret().as_str()), + ); + + (url, cookie) + } + + pub fn generate_macaroon( + &self, + state: &str, + nonce: &str, + redirect_url: &str, + pkce: Option<&str>, + ) -> String { + let key = services() + .globals + .macaroon + .unwrap_or_else(MacaroonKey::generate_random); + + let mut macaroon = Macaroon::create(None, &key, "oidc".into()).unwrap(); + let expires = chrono::Utc::now() + chrono::TimeDelta::seconds(COOKIE_STATE_EXPIRATION_SECS); + + let idp_id = self.config.metadata.id.as_str(); + + macaroon.add_first_party_caveat(format!("idp_id = {idp_id}").into()); + macaroon.add_first_party_caveat(format!("state = {state}").into()); + macaroon.add_first_party_caveat(format!("nonce = {nonce}").into()); + macaroon.add_first_party_caveat(format!("redirect_url = {redirect_url}").into()); + macaroon.add_first_party_caveat(format!("time < {expires}").into()); + + if let Some(verifier) = pkce { + macaroon.add_first_party_caveat(format!("verifier = {}", verifier).into()); + } + + macaroon.serialize(macaroon::Format::V2).unwrap() + } +} + +pub struct Service { + pub inner: Vec, +} + +impl Service { + pub async fn build(config: &Config) -> Arc { + Arc::new(Self { + inner: config.oidc.clone().into_iter().map(Provider::new).collect(), + }) + } + + pub async fn get_provider(&self, idp_id: impl AsRef) -> Result { + let Some(found) = self + .inner + .iter() + .find(|p| p.config.metadata.id == idp_id.as_ref()) + .map(Clone::clone) + else { + return Err(()); + }; + + // let client = found.client + // .get_or_try_init(|| async { Client::new(config.clone()).await }) + // .await + // .unwrap(); + + Ok(found) + } + + pub fn get_all(&self) -> &[Provider] { + self.inner.as_slice() + } + + // TODO + pub fn get_metadata(&self) -> Vec { + self.inner + .iter() + .map(|p| p.config.metadata.clone()) + .collect() + } + + // pub async fn generate_auth_url<'s, S>(&self, idp_id: String, scopes: S) -> () + // where + // S: Iterator, + // { + + // // TODO: PKCE challenge + // if false { + // let (challenge, verifier) = PkceCodeChallenge::new_random_sha256(); + + // // return auth_url.set_pkce_challenge(challenge).url() + // } + + // // auth_url.url() + // } + + // pub async fn exchange_auth_code(&self, auth_code: AuthorizationCode) { + // // -> Result { + // let provider = self.get_client("something").await.unwrap(); + + // let req = provider.exchange_code(auth_code); + // if false { + // // return req.set_pkce_verifier(pkce_verifier).request(http_client)? + // } + + // let resp = req.request(http_client).unwrap(); + + // let id_token = resp.id_token().unwrap(); + + // let access_token = resp.access_token(); + + // let claims = id_token + // .claims(&provider.id_token_verifier(), &Nonce::new("".into())) + // .unwrap(); + + // if let Some(hash) = claims.access_token_hash() { + // &AccessTokenHash::from_token(access_token, &id_token.signing_alg().unwrap()).unwrap() + // == hash; + // } + + // // need `UserInfo` endpoint + // // if let Some(subject) = config.subject_claim { + // if false { + // let req = provider + // .user_info( + // access_token.clone(), + // Some(SubjectIdentifier::new("id".to_owned())), + // ) + // .unwrap(); + // let ok: CoreUserInfoClaims = req.request(http_client).unwrap(); + // } + // } +} + +impl Client { + pub async fn new(config: ProviderConfig) -> Result { + let mut base_url = url::Url::try_from( + services() + .globals + .well_known_client() + .as_deref() + .unwrap_or(services().globals.server_name().as_str()), + ) + .expect("server_name should be a valid URL"); + + base_url.set_path("_conduit/client/oidc/callback"); + let redirect_url = RedirectUrl::from_url(base_url); + + let client = match config.discovery { + Discovery::Automatic => { + let url = config.issuer.to_string(); + let url = url.strip_suffix("/").unwrap(); + + let discovery = CoreProviderMetadata::discover_async( + // https://github.com/ramosbugs/openidconnect-rs/issues/77 + IssuerUrl::new(url.to_owned()).unwrap(), + async_http_client, + ) + .await + .unwrap(); + // .map_err(|e| Error::BadConfig(&e.to_string()))?; + + CoreClient::from_provider_metadata( + discovery, + ClientId::new(config.client.id), + config.client.secret.map(ClientSecret::new), + ) + } + Discovery::Manual(endpoints) => CoreClient::new( + ClientId::new(config.client.id), + config.client.secret.map(ClientSecret::new), + IssuerUrl::from_url(config.issuer), + AuthUrl::from_url(endpoints.auth), + endpoints.token.map(TokenUrl::from_url), + endpoints.userinfo.map(UserInfoUrl::from_url), + Default::default(), + ) + .set_redirect_uri(redirect_url), + }; + + Ok(Self(Arc::new(client))) + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 0b5b1ae4..2b268c26 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -4,13 +4,17 @@ use argon2::{Config, Variant}; use cmp::Ordering; use rand::prelude::*; use ring::digest; -use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; +use ruma::{ + canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, MxcUri, MxcUriError, +}; use std::{ cmp, fmt, str::FromStr, time::{SystemTime, UNIX_EPOCH}, }; +use crate::services; + pub fn millis_since_unix_epoch() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) @@ -180,3 +184,26 @@ impl<'a> fmt::Display for HtmlEscape<'a> { Ok(()) } } + +pub fn mxc_to_http_or_none(mxc_uri: Option<&MxcUri>, width: &str, height: &str) -> Option { + let Some(mxc_uri) = mxc_uri else { + return None; + }; + let base_url = services() + .globals + .well_known_client() + .as_deref() + .unwrap_or(services().globals.server_name().as_str()); + let (server_name, media_id) = mxc_uri.parts().ok()?; + + let mut host = + format!("https://{base_url}/_matrix/media/v3/thumbnail/{server_name}/{media_id}") + .parse::() + .expect("server_name should be a valid domain"); + host.query_pairs_mut() + .append_pair("width", width) + .append_pair("height", height) + .append_pair("method", "scale"); + + Some(host.to_string()) +}