Draft: SSO login (OAuth 2.0 + OpenID Connect) #1012

Open
avdb13 wants to merge 11 commits from oidc into next
14 changed files with 1201 additions and 586 deletions
Showing only changes of commit da0448f1b6 - Show all commits

1205
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -35,11 +35,11 @@ axum-extra = { version = "0.8.0", features = ["cookie"] }
axum-server = { version = "0.5.1", features = ["tls-rustls"] } axum-server = { version = "0.5.1", features = ["tls-rustls"] }
tower = { version = "0.4.13", features = ["util"] } tower = { version = "0.4.13", features = ["util"] }
tower-http = { version = "0.4.1", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] } tower-http = { version = "0.4.1", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] }
chrono = "0.4"
# Used for matrix spec type definitions and helpers # Used for matrix spec type definitions and helpers
#ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } #ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] }
ruma = { git = "https://github.com/ruma/ruma", rev = "1a1c61ee1e8f0936e956a3b69c931ce12ee28475", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] } #ruma = { git = "https://github.com/ruma/ruma", rev = "1a1c61ee1e8f0936e956a3b69c931ce12ee28475", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] }
ruma = { git = "https://github.com/avdb13/ruma", branch = "main", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] }
#ruma = { git = "https://github.com/timokoesters/ruma", rev = "4ec9c69bb7e09391add2382b3ebac97b6e8f4c64", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] } #ruma = { git = "https://github.com/timokoesters/ruma", rev = "4ec9c69bb7e09391add2382b3ebac97b6e8f4c64", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] }
#ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] } #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] }
@ -55,7 +55,8 @@ bytes = "1.4.0"
http = "0.2.9" http = "0.2.9"
# Used to find data directory for default db path # Used to find data directory for default db path
directories = "4.0.1" directories = "4.0.1"
macaroon = { git = "https://github.com/macaroon-rs/macaroon.git", branch = "main" } # Used for SSO authorization
macaroon = "0.3.0"
# Used for ruma wrapper # Used for ruma wrapper
serde_json = { version = "1.0.96", features = ["raw_value"] } serde_json = { version = "1.0.96", features = ["raw_value"] }
# Used for appservice registration files # Used for appservice registration files
@ -111,8 +112,6 @@ clap = { version = "4.3.0", default-features = false, features = ["std", "derive
futures-util = { version = "0.3.28", default-features = false } futures-util = { version = "0.3.28", default-features = false }
# Used for reading the configuration from conduit.toml & environment variables # Used for reading the configuration from conduit.toml & environment variables
figment = { version = "0.10.8", features = ["env", "toml"] } figment = { version = "0.10.8", features = ["env", "toml"] }
uuid = { version = "0.8", features = ["serde", "v4"] }
time = "0.3.22"
tikv-jemalloc-ctl = { version = "0.5.0", features = ["use_std"], optional = true } tikv-jemalloc-ctl = { version = "0.5.0", features = ["use_std"], optional = true }
tikv-jemallocator = { version = "0.5.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true } tikv-jemallocator = { version = "0.5.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true }
@ -120,16 +119,22 @@ lazy_static = "1.4.0"
async-trait = "0.1.68" async-trait = "0.1.68"
sd-notify = { version = "0.4.1", optional = true } sd-notify = { version = "0.4.1", optional = true }
url = { version = "2.5.0", features = ["serde"] }
# Used for SSO through OIDC
openidconnect = { version = "3.5.0", features = ["jwk-alg", "accept-string-booleans"] } openidconnect = { version = "3.5.0", features = ["jwk-alg", "accept-string-booleans"] }
url = { version = "2.5.0", features = ["serde"] }
# Used for SSO as fallback for non-web clients
askama = { version = "0.12.1", features = ["with-axum"] } askama = { version = "0.12.1", features = ["with-axum"] }
askama_axum = { version = "0.4.0", features = ["urlencode", "config"] } askama_axum = { version = "0.4.0", features = ["urlencode", "config"] }
time = "0.3.34"
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
nix = { version = "0.26.2", features = ["resource"] } nix = { version = "0.26.2", features = ["resource"] }
[features] [features]
default = ["conduit_bin", "backend_sqlite", "backend_rocksdb", "systemd"] #default = ["conduit_bin", "backend_sqlite", "backend_rocksdb", "systemd"]
default = ["conduit_bin", "backend_rocksdb"]
#backend_sled = ["sled"] #backend_sled = ["sled"]
backend_persy = ["persy", "parking_lot"] backend_persy = ["persy", "parking_lot"]
backend_sqlite = ["sqlite"] backend_sqlite = ["sqlite"]
@ -173,7 +178,7 @@ systemd-units = { unit-name = "matrix-conduit" }
[profile.dev] [profile.dev]
lto = 'off' lto = 'off'
incremental = true incremental = false
[profile.release] [profile.release]
lto = 'thin' lto = 'thin'

View file

@ -20,10 +20,10 @@
# for more information # for more information
# YOU NEED TO EDIT THIS # YOU NEED TO EDIT THIS
#server_name = "your.server.name" server_name = "localhost:6167"
# This is the only directory where Conduit will save its data # This is the only directory where Conduit will save its data
database_path = "/var/lib/matrix-conduit/" database_path = "./db"
database_backend = "rocksdb" database_backend = "rocksdb"
# The port Conduit will be running on. You need to set up a reverse proxy in # The port Conduit will be running on. You need to set up a reverse proxy in
@ -51,22 +51,22 @@ enable_lightning_bolt = true
trusted_servers = ["matrix.org"] trusted_servers = ["matrix.org"]
#max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time
#log = "warn,state_res=warn,rocket=off,_=off,sled=off" log = "info"
address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy
#address = "0.0.0.0" # If Conduit is running in a container, make sure the reverse proxy (ie. Traefik) can reach it. #address = "0.0.0.0" # If Conduit is running in a container, make sure the reverse proxy (ie. Traefik) can reach it.
macaroon_key = "this is the key" macaroon_key = "this is the key" # Currently only used in SSO as short-term login token
[[global.oidc]] [[global.sso]]
idp_id = "gitlab" id = "gitlab"
idp_name = "Gitlab" name = "Gitlab"
idp_icon = "mxc://matrix.org/00000000000000000000000000000000" icon = "mxc://kurosaki.cx/KKbSvyoUEYXdrzXBJoOJoLpZbqFYrCpW"
issuer = "https://gitlab.com" issuer = "https://gitlab.com"
scopes = ["openid", "profile"] scopes = ["openid", "profile"]
[global.oidc.client] [global.sso.client]
id = "0000000000000000000000000000000000000000000000000000000000000000" id = "12dd00d057420beda06fd5edcd21287026dc0c66ba5c02d40c2eff8b559c6709"
secret = "0000000000000000000000000000000000000000000000000000000000000000" secret = "3a806573cacf5da560b8c720cf32019255908c83e31ba78d280cf08a4eb619fd"
auth_method = "post" auth_method = "post"

View file

@ -46,7 +46,7 @@
<ul class="providers"> <ul class="providers">
{% for idp in metadata %} {% for idp in metadata %}
<li> <li>
<a href="pick_idp?idp={{ idp.id }}&redirectUrl={{ redirect_url|urlencode_strict }}"> <a href="redirect/{{ idp.id }}&redirectUrl={{ redirect_url|urlencode_strict }}">
{% match crate::utils::mxc_to_http_or_none(idp.icon.as_deref(), "32", "32") %} {% match crate::utils::mxc_to_http_or_none(idp.icon.as_deref(), "32", "32") %}
{% when Some with (icon) %} {% when Some with (icon) %}
<img src="{{ icon }}"/> <img src="{{ icon }}"/>

View file

@ -11,7 +11,6 @@ mod keys;
mod media; mod media;
mod membership; mod membership;
mod message; mod message;
mod oidc;
mod presence; mod presence;
mod profile; mod profile;
mod push; mod push;
@ -22,6 +21,7 @@ mod report;
mod room; mod room;
mod search; mod search;
mod session; mod session;
mod sso;
mod space; mod space;
mod state; mod state;
mod sync; mod sync;
@ -47,7 +47,6 @@ pub use keys::*;
pub use media::*; pub use media::*;
pub use membership::*; pub use membership::*;
pub use message::*; pub use message::*;
pub use oidc::*;
pub use presence::*; pub use presence::*;
pub use profile::*; pub use profile::*;
pub use push::*; pub use push::*;
@ -58,6 +57,7 @@ pub use report::*;
pub use room::*; pub use room::*;
pub use search::*; pub use search::*;
pub use session::*; pub use session::*;
pub use sso::*;
pub use space::*; pub use space::*;
pub use state::*; pub use state::*;
pub use sync::*; pub use sync::*;

View file

@ -1,114 +0,0 @@
use crate::{
config::Metadata, service::oidc::COOKIE_STATE_EXPIRATION_SECS, services, Error, Result, Ruma,
};
use askama::Template;
use axum::response::IntoResponse;
use axum_extra::extract::cookie::{Cookie, SameSite};
use bytes::BufMut;
use http::{header::COOKIE, HeaderValue, StatusCode};
use ruma::api::{
client::{error::ErrorKind, session},
error::IntoHttpError,
OutgoingResponse,
};
// const SEED_LEN: usize = 32;
/// # `GET /_matrix/client/v3/login/sso/redirect`
///
/// Redirect user to SSO interface.
///
pub async fn get_sso_redirect(
body: Ruma<session::sso_login::v3::Request>,
) -> axum::response::Response {
let server_name = services().globals.server_name().to_string();
let metadata = services().oidc.get_metadata();
let redirect_url = body.redirect_url.clone();
let t = SsoTemplate {
server_name,
metadata,
redirect_url,
};
match t.render() {
Ok(body) => {
let headers = [(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static(SsoTemplate::MIME_TYPE),
)];
(headers, body).into_response()
}
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "woops").into_response(),
}
}
pub struct SsoResponse {
pub inner: session::sso_login_with_provider::v3::Response,
pub cookie: String,
}
impl OutgoingResponse for SsoResponse {
fn try_into_http_response<T: Default + BufMut>(
self,
) -> Result<http::Response<T>, IntoHttpError> {
self.inner.try_into_http_response().map(|mut ok| {
*ok.status_mut() = StatusCode::FOUND;
match HeaderValue::from_str(self.cookie.as_str()) {
Ok(value) => {
ok.headers_mut().insert(COOKIE, value);
Ok(ok)
}
Err(e) => Err(IntoHttpError::Header(e)),
}
})?
}
}
/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
///
/// Redirect user to SSO interface.
pub async fn get_sso_redirect_with_idp_id(
body: Ruma<session::sso_login_with_provider::v3::Request>,
// State(uiaa_session): State<Option<()>>,
) -> Result<SsoResponse> {
// if services().oidc.get_all().len() == 1 {
// }
let Ok(provider) = services().oidc.get_provider(&body.idp_id).await else {
return Err(Error::BadRequest(
ErrorKind::NotFound,
"Unknown identity provider",
));
};
let (location, cookie) = provider.handle_redirect(&body.redirect_url).await;
let inner = session::sso_login_with_provider::v3::Response {
location: location.to_string(),
};
let cookie = Cookie::build("openid-state", cookie)
.path("/_conduit/client/oidc")
.secure(false) //FIXME
// .secure(true)
.http_only(true)
.same_site(SameSite::None)
.max_age(time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS))
.finish()
.to_string();
Ok(SsoResponse { inner, cookie })
// Ok((axum::http::StatusCode::FOUND, [(LOCATION, &body.redirect_url)]).into_response())
}
pub async fn get_sso_return() {}
#[derive(Template)]
#[template(path = "sso_login_idp_picker.html", escape = "none")]
struct SsoTemplate {
pub server_name: String,
pub metadata: Vec<Metadata>,
pub redirect_url: String,
}

View file

@ -1,20 +1,15 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma}; use crate::{services, utils, Error, Result, Ruma};
use base64::{alphabet, engine, engine::general_purpose};
// use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use macaroon::Verifier;
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
session::{get_login_types, login, logout, logout_all}, session::{get_login_types, login, logout, logout_all},
uiaa::UserIdentifier, uiaa::UserIdentifier,
}, },
events::GlobalAccountDataEventType, UserId,
push, UserId,
}; };
use serde::Deserialize; use serde::Deserialize;
use tracing::{debug, error, info, warn}; use tracing::{info, warn};
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct Claims { struct Claims {
@ -22,56 +17,6 @@ struct Claims {
//exp: usize, //exp: usize,
} }
#[tracing::instrument]
fn verifier_callback(v: &macaroon::ByteString) -> bool {
use std::num::ParseIntError;
let result: Result<bool, String> = (|| {
if v.0.starts_with(b"time < ") {
let v2 = std::str::from_utf8(&v.0).map_err(|e| e.to_string())?;
let v3 = v2.trim_start_matches("time < ");
let v4: i64 = v3.parse().map_err(|e: ParseIntError| e.to_string())?;
let now = chrono::Utc::now().timestamp();
if now < v4 {
debug!("macaroon is not expired yet");
Ok(true)
} else {
debug!(
"macaroon expired, v4={} , now={}, v4-now={}",
v4,
now,
v4 - now
);
Ok(false)
}
} else {
Ok(false)
}
})();
match result {
Ok(r) => r,
Err(e) => {
error!("verifier_callback: {:?}", e);
false
}
}
}
#[test]
fn test_verifier_callback() {
use macaroon::ByteString;
let now = chrono::Utc::now().timestamp();
assert!(verifier_callback(&ByteString(
format!("time < {}", now + 10).as_bytes().to_vec()
)));
assert!(!verifier_callback(&ByteString(
format!("time < {}", now - 10).as_bytes().to_vec()
)));
}
/// # `GET /_matrix/client/r0/login` /// # `GET /_matrix/client/r0/login`
/// ///
/// Get the supported login types of this server. One of these should be used as the `type` field /// Get the supported login types of this server. One of these should be used as the `type` field
@ -80,11 +25,10 @@ pub async fn get_login_types_route(
_body: Ruma<get_login_types::v3::Request>, _body: Ruma<get_login_types::v3::Request>,
) -> Result<get_login_types::v3::Response> { ) -> Result<get_login_types::v3::Response> {
let identity_providers = services() let identity_providers = services()
.oidc .sso
.get_metadata() .inner
.clone() .iter()
.into_iter() .map(|p| p.config.clone().into())
.map(Into::into)
.collect(); .collect();
Ok(get_login_types::v3::Response::new(vec![ Ok(get_login_types::v3::Response::new(vec![
@ -160,9 +104,6 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
user_id user_id
} }
login::v3::LoginInfo::Token(login::v3::Token { token }) => { login::v3::LoginInfo::Token(login::v3::Token { token }) => {
const CUSTOM_ENGINE: engine::GeneralPurpose =
engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD);
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
let token = jsonwebtoken::decode::<Claims>( let token = jsonwebtoken::decode::<Claims>(
token, token,
@ -174,57 +115,6 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
UserId::parse_with_server_name(username, services().globals.server_name()).map_err( UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."), |_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."),
)? )?
} else if macaroon::Macaroon::deserialize(&CUSTOM_ENGINE.decode(token).unwrap()) // TODO
.is_ok()
{
println!("TOKEN! {}", token);
let macaroon =
macaroon::Macaroon::deserialize(&CUSTOM_ENGINE.decode(token).unwrap()).unwrap();
let v1 = macaroon.identifier();
let user_id = std::str::from_utf8(&v1.0).unwrap();
println!("identifier: {}", user_id);
println!("location: {:?}", macaroon.location());
println!("sig: {:?}", macaroon.signature());
let mut verifier = Verifier::default();
verifier.satisfy_general(verifier_callback);
// let openid_client = &services().globals.openid_client;
// let (key, _client) = openid_client.as_ref().unwrap();
// match verifier.verify(&macaroon, &key, Default::default()) {
// Ok(()) => println!("Macaroon verified!"),
// Err(error) => println!("Error validating macaroon: {:?}", error),
// }
let user_id =
UserId::parse_with_server_name(user_id, services().globals.server_name())
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?;
println!("user_id: {}", user_id);
if !services().users.exists(&user_id)? {
let random_password = crate::utils::random_string(TOKEN_LENGTH);
services().users.create(&user_id, Some(&random_password))?;
services().account_data.update(
None,
&user_id,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
content: ruma::events::push_rules::PushRulesEventContent {
global: push::Ruleset::server_default(&user_id),
},
})
.expect("to json always works"),
)?;
}
user_id
} else { } else {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Unknown, ErrorKind::Unknown,

View file

@ -0,0 +1,102 @@
use crate::{
service::sso::{templates, COOKIE_STATE_EXPIRATION_SECS},
services, Error, Ruma, RumaResponse,
};
use askama::Template;
use axum::{body::Full, response::IntoResponse};
use axum_extra::extract::cookie::{Cookie, SameSite};
use bytes::BytesMut;
use http::StatusCode;
use ruma::api::{
client::{error::ErrorKind, session},
OutgoingResponse,
};
/// # `GET /_matrix/client/v3/login/sso/redirect`
///
/// Redirect user to SSO interface. The path argument is optional.
pub async fn get_sso_redirect(
body: Ruma<session::sso_login::v3::Request>,
) -> axum::response::Response {
if services().sso.get_all().is_empty() {
return Error::BadRequest(ErrorKind::NotFound, "SSO has not been configured")
.into_response();
}
return get_sso_fallback_template(body.redirect_url.as_deref().unwrap_or_default()).into_response();
}
/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
///
/// Redirect user to SSO interface.
pub async fn get_sso_redirect_with_provider(
// State(uiaa_session): State<Option<()>>,
body: Ruma<session::sso_login_with_provider::v3::Request>,
) -> axum::response::Response {
if services().sso.get_all().is_empty() {
return Error::BadRequest(ErrorKind::NotFound, "SSO has not been configured")
.into_response();
}
if body.idp_id.is_empty() {
return get_sso_fallback_template(body.redirect_url.as_deref().unwrap_or_default()).into_response();
};
let Some(provider) = services().sso.get_provider(&body.idp_id) else {
return Error::BadRequest(ErrorKind::NotFound, "Unknown identity provider").into_response();
};
let (location, cookie) = provider.handle_redirect(body.redirect_url.as_deref().unwrap_or_default()).await;
let cookie = Cookie::build("openid-state", cookie)
.path("/_conduit/client/sso")
// .secure(false) //FIXME
.secure(true)
.http_only(true)
.same_site(SameSite::None)
.max_age(time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS))
.finish()
.to_string();
let mut res = session::sso_login_with_provider::v3::Response {
location: Some(location.to_string()),
cookie: Some(cookie),
}
.try_into_http_response::<BytesMut>()
.unwrap();
*res.status_mut() = StatusCode::FOUND;
res.map(BytesMut::freeze).map(Full::new).into_response()
}
fn get_sso_fallback_template(redirect_url: &str) -> axum::response::Response {
let server_name = services().globals.server_name().to_string();
let metadata = services().sso.inner.iter().map(Into::into).collect();
let redirect_url = redirect_url.to_string();
let t = templates::IdpPicker {
server_name,
metadata,
redirect_url,
};
t.render()
.map(|body| {
((
[(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static(templates::IdpPicker::MIME_TYPE),
)],
body,
))
.into_response()
})
.expect("woops")
}
/// # `GET /_conduit/client/oidc/callback`
///
/// Verify the response received from the identity provider.
/// If everything is fine redirect
pub async fn get_sso_callback() {}

View file

@ -9,11 +9,11 @@ use serde::{de::IgnoredAny, Deserialize};
use tracing::warn; use tracing::warn;
mod proxy; mod proxy;
mod oidc; mod sso;
pub use oidc::*; use self::proxy::ProxyConfig;
use self::{oidc::OidcConfig, proxy::ProxyConfig}; pub use sso::*;
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Config { pub struct Config {
@ -85,7 +85,7 @@ pub struct Config {
#[serde(default)] #[serde(default)]
pub macaroon_key: Option<String>, pub macaroon_key: Option<String>,
#[serde(default)] #[serde(default)]
pub oidc: OidcConfig, pub sso: Vec<ProviderConfig>,
pub emergency_password: Option<String>, pub emergency_password: Option<String>,

View file

@ -1,51 +1,18 @@
use openidconnect::JsonWebKeyId; use openidconnect::JsonWebKeyId;
use ruma::{ use ruma::OwnedMxcUri;
api::client::session::get_login_types::v3::{IdentityProvider, IdentityProviderBrand},
OwnedMxcUri,
};
use serde::Deserialize; use serde::Deserialize;
pub type OidcConfig = Vec<ProviderConfig>;
#[derive(Clone, Debug, Deserialize)]
pub struct Metadata {
// Must be unique, used to distinguish OPs
#[serde(rename = "idp_id")]
pub id: String,
#[serde(rename = "idp_name")]
pub name: Option<String>,
#[serde(rename = "idp_icon")]
pub icon: Option<OwnedMxcUri>,
}
impl Into<IdentityProvider> for Metadata {
fn into(self) -> IdentityProvider {
let brand = match IdentityProviderBrand::from(self.id.clone()) {
IdentityProviderBrand::_Custom(_) => None,
brand => Some(brand),
};
IdentityProvider {
id: self.id.clone(),
name: self.name.unwrap_or(self.id),
icon: self.icon,
brand,
}
}
}
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct ProviderConfig { pub struct ProviderConfig {
pub id: String,
pub name: Option<String>,
pub icon: Option<OwnedMxcUri>,
// Information retrieved while creating the OpenID Application // Information retrieved while creating the OpenID Application
pub client: ClientConfig, pub client: ClientConfig,
// Information for displaying the OpenID Provider
#[serde(flatten)]
pub metadata: Metadata,
// Foo
// #[serde(deserialize_with = "crate::utils::deserialize_from_str")] // #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
pub issuer: url::Url, pub issuer: url::Url,

View file

@ -24,7 +24,7 @@ use ruma::api::{
IncomingRequest, IncomingRequest,
}; };
use tokio::signal; use tokio::signal;
use tower::ServiceBuilder; use tower::{ServiceBuilder, ServiceExt};
use tower_http::{ use tower_http::{
cors::{self, CorsLayer}, cors::{self, CorsLayer},
trace::TraceLayer, trace::TraceLayer,
@ -401,11 +401,22 @@ fn routes() -> Router {
"/_matrix/key/v2/server/:key_id", "/_matrix/key/v2/server/:key_id",
get(server_server::get_server_keys_deprecated_route), get(server_server::get_server_keys_deprecated_route),
) )
// .route( .route(
// "/_matrix/client/v3/login/sso/redirect", "/_matrix/client/v3/login/sso/redirect",
// get(client_server::get_sso_redirect), get(client_server::get_sso_redirect),
// ) )
.ruma_route(client_server::get_sso_redirect_with_idp_id) .route(
"/_matrix/client/v3/login/sso/redirect/",
get(client_server::get_sso_redirect),
)
.route(
"/_matrix/client/v3/login/sso/redirect/:idp_id",
get(client_server::get_sso_redirect_with_provider),
)
.route(
"/_conduit/v3/login/sso/callback",
get(client_server::get_sso_callback),
)
// .route("/sso_return", get(client_server::get_sso_return)) // .route("/sso_return", get(client_server::get_sso_return))
.ruma_route(server_server::get_public_rooms_route) .ruma_route(server_server::get_public_rooms_route)
.ruma_route(server_server::get_public_rooms_filtered_route) .ruma_route(server_server::get_public_rooms_filtered_route)

View file

@ -13,7 +13,7 @@ pub mod appservice;
pub mod globals; pub mod globals;
pub mod key_backups; pub mod key_backups;
pub mod media; pub mod media;
pub mod oidc; pub mod sso;
pub mod pdu; pub mod pdu;
pub mod pusher; pub mod pusher;
pub mod rooms; pub mod rooms;
@ -35,7 +35,7 @@ pub struct Services {
pub key_backups: key_backups::Service, pub key_backups: key_backups::Service,
pub media: media::Service, pub media: media::Service,
pub sending: Arc<sending::Service>, pub sending: Arc<sending::Service>,
pub oidc: Arc<oidc::Service>, pub sso: Arc<sso::Service>,
} }
impl Services { impl Services {
@ -116,8 +116,7 @@ impl Services {
key_backups: key_backups::Service { db }, key_backups: key_backups::Service { db },
media: media::Service { db }, media: media::Service { db },
sending: sending::Service::build(db, &config), sending: sending::Service::build(db, &config),
sso: sso::Service::build(&config).await,
oidc: oidc::Service::build(&config).await,
globals: globals::Service::load(db, config).await?, globals: globals::Service::load(db, config).await?,
}) })
} }

View file

@ -7,14 +7,14 @@ use openidconnect::{
AuthUrl, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, PkceCodeChallenge, RedirectUrl, AuthUrl, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, PkceCodeChallenge, RedirectUrl,
Scope, TokenUrl, UserInfoUrl, Scope, TokenUrl, UserInfoUrl,
}; };
use ruma::api::client::session::get_login_types::v3::{IdentityProvider, IdentityProviderBrand};
use tokio::sync::OnceCell; use tokio::sync::OnceCell;
use crate::{ use crate::{config::{DiscoveryConfig as Discovery, ProviderConfig}, services, Config, Error};
config::{DiscoveryConfig as Discovery, Metadata, ProviderConfig},
services, Config, Error,
};
pub const COOKIE_STATE_EXPIRATION_SECS: i64 = 10 * 60; pub const COOKIE_STATE_EXPIRATION_SECS: i64 = 60 * 60;
pub mod templates;
#[derive(Clone)] #[derive(Clone)]
pub struct Client(Arc<CoreClient>); pub struct Client(Arc<CoreClient>);
@ -47,15 +47,11 @@ impl Provider {
.map(ToOwned::to_owned) .map(ToOwned::to_owned)
.map(Scope::new); .map(Scope::new);
let csrf = CsrfToken::new_random();
let nonce = Nonce::new_random();
let tmp = (csrf.clone(), nonce.clone());
let mut req = client let mut req = client
.authorize_url( .authorize_url(
CoreAuthenticationFlow::Implicit(true), CoreAuthenticationFlow::Implicit(true),
move || csrf, || CsrfToken::new_random_len(36),
move || nonce, || Nonce::new_random_len(36),
) )
.add_scopes(scopes); .add_scopes(scopes);
@ -88,10 +84,12 @@ impl Provider {
.macaroon .macaroon
.unwrap_or_else(MacaroonKey::generate_random); .unwrap_or_else(MacaroonKey::generate_random);
let mut macaroon = Macaroon::create(None, &key, "oidc".into()).unwrap(); let mut macaroon = Macaroon::create(None, &key, "sso".into()).unwrap();
let expires = chrono::Utc::now() + chrono::TimeDelta::seconds(COOKIE_STATE_EXPIRATION_SECS); let expires = (time::OffsetDateTime::now_utc()
+ time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS))
.to_string();
let idp_id = self.config.metadata.id.as_str(); let idp_id = self.config.id.as_str();
macaroon.add_first_party_caveat(format!("idp_id = {idp_id}").into()); macaroon.add_first_party_caveat(format!("idp_id = {idp_id}").into());
macaroon.add_first_party_caveat(format!("state = {state}").into()); macaroon.add_first_party_caveat(format!("state = {state}").into());
@ -107,6 +105,23 @@ impl Provider {
} }
} }
impl Into<IdentityProvider> for ProviderConfig {
fn into(self) -> IdentityProvider {
let brand = match IdentityProviderBrand::from(self.id.clone()) {
IdentityProviderBrand::_Custom(_) => None,
brand => Some(brand),
};
IdentityProvider {
id: self.id.clone(),
name: self.name.unwrap_or(self.id),
icon: self.icon,
brand,
}
}
}
pub struct Service { pub struct Service {
pub inner: Vec<Provider>, pub inner: Vec<Provider>,
} }
@ -114,39 +129,22 @@ pub struct Service {
impl Service { impl Service {
pub async fn build(config: &Config) -> Arc<Self> { pub async fn build(config: &Config) -> Arc<Self> {
Arc::new(Self { Arc::new(Self {
inner: config.oidc.clone().into_iter().map(Provider::new).collect(), inner: config.sso.clone().into_iter().map(Provider::new).collect(),
}) })
} }
pub async fn get_provider(&self, idp_id: impl AsRef<str>) -> Result<Provider, ()> { pub fn get_provider(&self, idp_id: impl AsRef<str>) -> Option<Provider> {
let Some(found) = self self.inner
.inner
.iter() .iter()
.find(|p| p.config.metadata.id == idp_id.as_ref()) .find(|p| p.config.id == idp_id.as_ref())
.map(Clone::clone) .map(ToOwned::to_owned)
else {
return Err(());
};
// let client = found.client
// .get_or_try_init(|| async { Client::new(config.clone()).await })
// .await
// .unwrap();
Ok(found)
} }
pub fn get_all(&self) -> &[Provider] { pub fn get_all(&self) -> &[Provider] {
self.inner.as_slice() self.inner.as_slice()
} }
// TODO pub fn validate_session(&self) {}
pub fn get_metadata(&self) -> Vec<Metadata> {
self.inner
.iter()
.map(|p| p.config.metadata.clone())
.collect()
}
// pub async fn generate_auth_url<'s, S>(&self, idp_id: String, scopes: S) -> () // pub async fn generate_auth_url<'s, S>(&self, idp_id: String, scopes: S) -> ()
// where // where
@ -212,7 +210,7 @@ impl Client {
) )
.expect("server_name should be a valid URL"); .expect("server_name should be a valid URL");
base_url.set_path("_conduit/client/oidc/callback"); base_url.set_path("_conduit/client/sso/callback");
let redirect_url = RedirectUrl::from_url(base_url); let redirect_url = RedirectUrl::from_url(base_url);
let client = match config.discovery { let client = match config.discovery {

View file

@ -1,32 +1,50 @@
use askama::Template; use askama::Template;
use ruma::{OwnedUserId, api::client::search::search_events::v3::UserProfile, OwnedServerName}; use ruma::{
api::client::search::search_events::v3::UserProfile, OwnedMxcUri, OwnedServerName, OwnedUserId,
};
use crate::config::Metadata; use super::Provider;
#[derive(Template)] #[derive(Template)]
#[template(path = "auth_confirmation.html", escape = "none")] #[template(path = "auth_confirmation.html", escape = "none")]
pub struct AuthConfirmationTemplate { pub struct AuthConfirmation {
description: String, description: String,
redirect_url: url::Url, redirect_url: url::Url,
idp_name: String, idp_name: String,
} }
pub struct Metadata {
id: String,
name: Option<String>,
icon: Option<OwnedMxcUri>,
}
impl From<&Provider> for Metadata {
fn from(value: &Provider) -> Self {
Self {
id: value.config.id.clone(),
name: value.config.name.clone(),
icon: value.config.icon.clone(),
}
}
}
#[derive(Template)] #[derive(Template)]
#[template(path = "auth_failure.html", escape = "none")] #[template(path = "auth_failure.html", escape = "none")]
pub struct AuthFailureTemplate { pub struct AuthFailure {
server_name: OwnedServerName, server_name: OwnedServerName,
} }
#[derive(Template)] #[derive(Template)]
#[template(path = "auth_success.html", escape = "none")] #[template(path = "auth_success.html", escape = "none")]
pub struct AuthSuccessTemplate {} pub struct AuthSuccess {}
#[derive(Template)] #[derive(Template)]
#[template(path = "deactivated.html", escape = "none")] #[template(path = "deactivated.html", escape = "none")]
pub struct DeactivatedTemplate {} pub struct Deactivated {}
#[derive(Template)] #[derive(Template)]
#[template(path = "idp_picker.html", escape = "none")] #[template(path = "idp_picker.html", escape = "none")]
pub struct IdpPickerTemplate { pub struct IdpPicker {
pub server_name: String, pub server_name: String,
pub metadata: Vec<Metadata>, pub metadata: Vec<Metadata>,
pub redirect_url: String, pub redirect_url: String,
@ -34,7 +52,7 @@ pub struct IdpPickerTemplate {
#[derive(Template)] #[derive(Template)]
#[template(path = "registration.html", escape = "none")] #[template(path = "registration.html", escape = "none")]
pub struct RegistrationTemplate { pub struct Registration {
pub server_name: OwnedServerName, pub server_name: OwnedServerName,
pub idp: Metadata, pub idp: Metadata,
pub user: Attributes, pub user: Attributes,
@ -49,7 +67,7 @@ pub struct Attributes {
#[derive(Template)] #[derive(Template)]
#[template(path = "redirect_confirm.html", escape = "none")] #[template(path = "redirect_confirm.html", escape = "none")]
pub struct RedirectConfirmTemplate { pub struct RedirectConfirm {
pub user_id: OwnedUserId, pub user_id: OwnedUserId,
pub user_profile: UserProfile, pub user_profile: UserProfile,
pub display_url: url::Url, pub display_url: url::Url,