Matrix endpoints working fine

This commit is contained in:
mikoto 2024-02-26 13:27:15 +00:00
parent e4bda7bb42
commit da0448f1b6
14 changed files with 1201 additions and 586 deletions

1205
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -35,11 +35,11 @@ axum-extra = { version = "0.8.0", features = ["cookie"] }
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
tower = { version = "0.4.13", features = ["util"] }
tower-http = { version = "0.4.1", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] }
chrono = "0.4"
# Used for matrix spec type definitions and helpers
#ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] }
ruma = { git = "https://github.com/ruma/ruma", rev = "1a1c61ee1e8f0936e956a3b69c931ce12ee28475", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] }
#ruma = { git = "https://github.com/ruma/ruma", rev = "1a1c61ee1e8f0936e956a3b69c931ce12ee28475", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] }
ruma = { git = "https://github.com/avdb13/ruma", branch = "main", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] }
#ruma = { git = "https://github.com/timokoesters/ruma", rev = "4ec9c69bb7e09391add2382b3ebac97b6e8f4c64", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] }
#ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] }
@ -55,7 +55,8 @@ bytes = "1.4.0"
http = "0.2.9"
# Used to find data directory for default db path
directories = "4.0.1"
macaroon = { git = "https://github.com/macaroon-rs/macaroon.git", branch = "main" }
# Used for SSO authorization
macaroon = "0.3.0"
# Used for ruma wrapper
serde_json = { version = "1.0.96", features = ["raw_value"] }
# Used for appservice registration files
@ -111,8 +112,6 @@ clap = { version = "4.3.0", default-features = false, features = ["std", "derive
futures-util = { version = "0.3.28", default-features = false }
# Used for reading the configuration from conduit.toml & environment variables
figment = { version = "0.10.8", features = ["env", "toml"] }
uuid = { version = "0.8", features = ["serde", "v4"] }
time = "0.3.22"
tikv-jemalloc-ctl = { version = "0.5.0", features = ["use_std"], optional = true }
tikv-jemallocator = { version = "0.5.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true }
@ -120,16 +119,22 @@ lazy_static = "1.4.0"
async-trait = "0.1.68"
sd-notify = { version = "0.4.1", optional = true }
url = { version = "2.5.0", features = ["serde"] }
# Used for SSO through OIDC
openidconnect = { version = "3.5.0", features = ["jwk-alg", "accept-string-booleans"] }
url = { version = "2.5.0", features = ["serde"] }
# Used for SSO as fallback for non-web clients
askama = { version = "0.12.1", features = ["with-axum"] }
askama_axum = { version = "0.4.0", features = ["urlencode", "config"] }
time = "0.3.34"
[target.'cfg(unix)'.dependencies]
nix = { version = "0.26.2", features = ["resource"] }
[features]
default = ["conduit_bin", "backend_sqlite", "backend_rocksdb", "systemd"]
#default = ["conduit_bin", "backend_sqlite", "backend_rocksdb", "systemd"]
default = ["conduit_bin", "backend_rocksdb"]
#backend_sled = ["sled"]
backend_persy = ["persy", "parking_lot"]
backend_sqlite = ["sqlite"]
@ -173,7 +178,7 @@ systemd-units = { unit-name = "matrix-conduit" }
[profile.dev]
lto = 'off'
incremental = true
incremental = false
[profile.release]
lto = 'thin'

View file

@ -20,10 +20,10 @@
# for more information
# YOU NEED TO EDIT THIS
#server_name = "your.server.name"
server_name = "localhost:6167"
# This is the only directory where Conduit will save its data
database_path = "/var/lib/matrix-conduit/"
database_path = "./db"
database_backend = "rocksdb"
# The port Conduit will be running on. You need to set up a reverse proxy in
@ -51,22 +51,22 @@ enable_lightning_bolt = true
trusted_servers = ["matrix.org"]
#max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time
#log = "warn,state_res=warn,rocket=off,_=off,sled=off"
log = "info"
address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy
#address = "0.0.0.0" # If Conduit is running in a container, make sure the reverse proxy (ie. Traefik) can reach it.
macaroon_key = "this is the key"
macaroon_key = "this is the key" # Currently only used in SSO as short-term login token
[[global.oidc]]
idp_id = "gitlab"
idp_name = "Gitlab"
idp_icon = "mxc://matrix.org/00000000000000000000000000000000"
[[global.sso]]
id = "gitlab"
name = "Gitlab"
icon = "mxc://kurosaki.cx/KKbSvyoUEYXdrzXBJoOJoLpZbqFYrCpW"
issuer = "https://gitlab.com"
scopes = ["openid", "profile"]
[global.oidc.client]
id = "0000000000000000000000000000000000000000000000000000000000000000"
secret = "0000000000000000000000000000000000000000000000000000000000000000"
[global.sso.client]
id = "12dd00d057420beda06fd5edcd21287026dc0c66ba5c02d40c2eff8b559c6709"
secret = "3a806573cacf5da560b8c720cf32019255908c83e31ba78d280cf08a4eb619fd"
auth_method = "post"

View file

@ -46,7 +46,7 @@
<ul class="providers">
{% for idp in metadata %}
<li>
<a href="pick_idp?idp={{ idp.id }}&redirectUrl={{ redirect_url|urlencode_strict }}">
<a href="redirect/{{ idp.id }}&redirectUrl={{ redirect_url|urlencode_strict }}">
{% match crate::utils::mxc_to_http_or_none(idp.icon.as_deref(), "32", "32") %}
{% when Some with (icon) %}
<img src="{{ icon }}"/>

View file

@ -11,7 +11,6 @@ mod keys;
mod media;
mod membership;
mod message;
mod oidc;
mod presence;
mod profile;
mod push;
@ -22,6 +21,7 @@ mod report;
mod room;
mod search;
mod session;
mod sso;
mod space;
mod state;
mod sync;
@ -47,7 +47,6 @@ pub use keys::*;
pub use media::*;
pub use membership::*;
pub use message::*;
pub use oidc::*;
pub use presence::*;
pub use profile::*;
pub use push::*;
@ -58,6 +57,7 @@ pub use report::*;
pub use room::*;
pub use search::*;
pub use session::*;
pub use sso::*;
pub use space::*;
pub use state::*;
pub use sync::*;

View file

@ -1,114 +0,0 @@
use crate::{
config::Metadata, service::oidc::COOKIE_STATE_EXPIRATION_SECS, services, Error, Result, Ruma,
};
use askama::Template;
use axum::response::IntoResponse;
use axum_extra::extract::cookie::{Cookie, SameSite};
use bytes::BufMut;
use http::{header::COOKIE, HeaderValue, StatusCode};
use ruma::api::{
client::{error::ErrorKind, session},
error::IntoHttpError,
OutgoingResponse,
};
// const SEED_LEN: usize = 32;
/// # `GET /_matrix/client/v3/login/sso/redirect`
///
/// Redirect user to SSO interface.
///
pub async fn get_sso_redirect(
body: Ruma<session::sso_login::v3::Request>,
) -> axum::response::Response {
let server_name = services().globals.server_name().to_string();
let metadata = services().oidc.get_metadata();
let redirect_url = body.redirect_url.clone();
let t = SsoTemplate {
server_name,
metadata,
redirect_url,
};
match t.render() {
Ok(body) => {
let headers = [(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static(SsoTemplate::MIME_TYPE),
)];
(headers, body).into_response()
}
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "woops").into_response(),
}
}
pub struct SsoResponse {
pub inner: session::sso_login_with_provider::v3::Response,
pub cookie: String,
}
impl OutgoingResponse for SsoResponse {
fn try_into_http_response<T: Default + BufMut>(
self,
) -> Result<http::Response<T>, IntoHttpError> {
self.inner.try_into_http_response().map(|mut ok| {
*ok.status_mut() = StatusCode::FOUND;
match HeaderValue::from_str(self.cookie.as_str()) {
Ok(value) => {
ok.headers_mut().insert(COOKIE, value);
Ok(ok)
}
Err(e) => Err(IntoHttpError::Header(e)),
}
})?
}
}
/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
///
/// Redirect user to SSO interface.
pub async fn get_sso_redirect_with_idp_id(
body: Ruma<session::sso_login_with_provider::v3::Request>,
// State(uiaa_session): State<Option<()>>,
) -> Result<SsoResponse> {
// if services().oidc.get_all().len() == 1 {
// }
let Ok(provider) = services().oidc.get_provider(&body.idp_id).await else {
return Err(Error::BadRequest(
ErrorKind::NotFound,
"Unknown identity provider",
));
};
let (location, cookie) = provider.handle_redirect(&body.redirect_url).await;
let inner = session::sso_login_with_provider::v3::Response {
location: location.to_string(),
};
let cookie = Cookie::build("openid-state", cookie)
.path("/_conduit/client/oidc")
.secure(false) //FIXME
// .secure(true)
.http_only(true)
.same_site(SameSite::None)
.max_age(time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS))
.finish()
.to_string();
Ok(SsoResponse { inner, cookie })
// Ok((axum::http::StatusCode::FOUND, [(LOCATION, &body.redirect_url)]).into_response())
}
pub async fn get_sso_return() {}
#[derive(Template)]
#[template(path = "sso_login_idp_picker.html", escape = "none")]
struct SsoTemplate {
pub server_name: String,
pub metadata: Vec<Metadata>,
pub redirect_url: String,
}

View file

@ -1,20 +1,15 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma};
use base64::{alphabet, engine, engine::general_purpose};
// use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use macaroon::Verifier;
use ruma::{
api::client::{
error::ErrorKind,
session::{get_login_types, login, logout, logout_all},
uiaa::UserIdentifier,
},
events::GlobalAccountDataEventType,
push, UserId,
UserId,
};
use serde::Deserialize;
use tracing::{debug, error, info, warn};
use tracing::{info, warn};
#[derive(Debug, Deserialize)]
struct Claims {
@ -22,56 +17,6 @@ struct Claims {
//exp: usize,
}
#[tracing::instrument]
fn verifier_callback(v: &macaroon::ByteString) -> bool {
use std::num::ParseIntError;
let result: Result<bool, String> = (|| {
if v.0.starts_with(b"time < ") {
let v2 = std::str::from_utf8(&v.0).map_err(|e| e.to_string())?;
let v3 = v2.trim_start_matches("time < ");
let v4: i64 = v3.parse().map_err(|e: ParseIntError| e.to_string())?;
let now = chrono::Utc::now().timestamp();
if now < v4 {
debug!("macaroon is not expired yet");
Ok(true)
} else {
debug!(
"macaroon expired, v4={} , now={}, v4-now={}",
v4,
now,
v4 - now
);
Ok(false)
}
} else {
Ok(false)
}
})();
match result {
Ok(r) => r,
Err(e) => {
error!("verifier_callback: {:?}", e);
false
}
}
}
#[test]
fn test_verifier_callback() {
use macaroon::ByteString;
let now = chrono::Utc::now().timestamp();
assert!(verifier_callback(&ByteString(
format!("time < {}", now + 10).as_bytes().to_vec()
)));
assert!(!verifier_callback(&ByteString(
format!("time < {}", now - 10).as_bytes().to_vec()
)));
}
/// # `GET /_matrix/client/r0/login`
///
/// Get the supported login types of this server. One of these should be used as the `type` field
@ -80,11 +25,10 @@ pub async fn get_login_types_route(
_body: Ruma<get_login_types::v3::Request>,
) -> Result<get_login_types::v3::Response> {
let identity_providers = services()
.oidc
.get_metadata()
.clone()
.into_iter()
.map(Into::into)
.sso
.inner
.iter()
.map(|p| p.config.clone().into())
.collect();
Ok(get_login_types::v3::Response::new(vec![
@ -160,9 +104,6 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
user_id
}
login::v3::LoginInfo::Token(login::v3::Token { token }) => {
const CUSTOM_ENGINE: engine::GeneralPurpose =
engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD);
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
let token = jsonwebtoken::decode::<Claims>(
token,
@ -174,57 +115,6 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."),
)?
} else if macaroon::Macaroon::deserialize(&CUSTOM_ENGINE.decode(token).unwrap()) // TODO
.is_ok()
{
println!("TOKEN! {}", token);
let macaroon =
macaroon::Macaroon::deserialize(&CUSTOM_ENGINE.decode(token).unwrap()).unwrap();
let v1 = macaroon.identifier();
let user_id = std::str::from_utf8(&v1.0).unwrap();
println!("identifier: {}", user_id);
println!("location: {:?}", macaroon.location());
println!("sig: {:?}", macaroon.signature());
let mut verifier = Verifier::default();
verifier.satisfy_general(verifier_callback);
// let openid_client = &services().globals.openid_client;
// let (key, _client) = openid_client.as_ref().unwrap();
// match verifier.verify(&macaroon, &key, Default::default()) {
// Ok(()) => println!("Macaroon verified!"),
// Err(error) => println!("Error validating macaroon: {:?}", error),
// }
let user_id =
UserId::parse_with_server_name(user_id, services().globals.server_name())
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?;
println!("user_id: {}", user_id);
if !services().users.exists(&user_id)? {
let random_password = crate::utils::random_string(TOKEN_LENGTH);
services().users.create(&user_id, Some(&random_password))?;
services().account_data.update(
None,
&user_id,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
content: ruma::events::push_rules::PushRulesEventContent {
global: push::Ruleset::server_default(&user_id),
},
})
.expect("to json always works"),
)?;
}
user_id
} else {
return Err(Error::BadRequest(
ErrorKind::Unknown,

View file

@ -0,0 +1,102 @@
use crate::{
service::sso::{templates, COOKIE_STATE_EXPIRATION_SECS},
services, Error, Ruma, RumaResponse,
};
use askama::Template;
use axum::{body::Full, response::IntoResponse};
use axum_extra::extract::cookie::{Cookie, SameSite};
use bytes::BytesMut;
use http::StatusCode;
use ruma::api::{
client::{error::ErrorKind, session},
OutgoingResponse,
};
/// # `GET /_matrix/client/v3/login/sso/redirect`
///
/// Redirect user to SSO interface. The path argument is optional.
pub async fn get_sso_redirect(
body: Ruma<session::sso_login::v3::Request>,
) -> axum::response::Response {
if services().sso.get_all().is_empty() {
return Error::BadRequest(ErrorKind::NotFound, "SSO has not been configured")
.into_response();
}
return get_sso_fallback_template(body.redirect_url.as_deref().unwrap_or_default()).into_response();
}
/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
///
/// Redirect user to SSO interface.
pub async fn get_sso_redirect_with_provider(
// State(uiaa_session): State<Option<()>>,
body: Ruma<session::sso_login_with_provider::v3::Request>,
) -> axum::response::Response {
if services().sso.get_all().is_empty() {
return Error::BadRequest(ErrorKind::NotFound, "SSO has not been configured")
.into_response();
}
if body.idp_id.is_empty() {
return get_sso_fallback_template(body.redirect_url.as_deref().unwrap_or_default()).into_response();
};
let Some(provider) = services().sso.get_provider(&body.idp_id) else {
return Error::BadRequest(ErrorKind::NotFound, "Unknown identity provider").into_response();
};
let (location, cookie) = provider.handle_redirect(body.redirect_url.as_deref().unwrap_or_default()).await;
let cookie = Cookie::build("openid-state", cookie)
.path("/_conduit/client/sso")
// .secure(false) //FIXME
.secure(true)
.http_only(true)
.same_site(SameSite::None)
.max_age(time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS))
.finish()
.to_string();
let mut res = session::sso_login_with_provider::v3::Response {
location: Some(location.to_string()),
cookie: Some(cookie),
}
.try_into_http_response::<BytesMut>()
.unwrap();
*res.status_mut() = StatusCode::FOUND;
res.map(BytesMut::freeze).map(Full::new).into_response()
}
fn get_sso_fallback_template(redirect_url: &str) -> axum::response::Response {
let server_name = services().globals.server_name().to_string();
let metadata = services().sso.inner.iter().map(Into::into).collect();
let redirect_url = redirect_url.to_string();
let t = templates::IdpPicker {
server_name,
metadata,
redirect_url,
};
t.render()
.map(|body| {
((
[(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static(templates::IdpPicker::MIME_TYPE),
)],
body,
))
.into_response()
})
.expect("woops")
}
/// # `GET /_conduit/client/oidc/callback`
///
/// Verify the response received from the identity provider.
/// If everything is fine redirect
pub async fn get_sso_callback() {}

View file

@ -9,11 +9,11 @@ use serde::{de::IgnoredAny, Deserialize};
use tracing::warn;
mod proxy;
mod oidc;
mod sso;
pub use oidc::*;
use self::proxy::ProxyConfig;
use self::{oidc::OidcConfig, proxy::ProxyConfig};
pub use sso::*;
#[derive(Clone, Debug, Deserialize)]
pub struct Config {
@ -85,7 +85,7 @@ pub struct Config {
#[serde(default)]
pub macaroon_key: Option<String>,
#[serde(default)]
pub oidc: OidcConfig,
pub sso: Vec<ProviderConfig>,
pub emergency_password: Option<String>,

View file

@ -1,51 +1,18 @@
use openidconnect::JsonWebKeyId;
use ruma::{
api::client::session::get_login_types::v3::{IdentityProvider, IdentityProviderBrand},
OwnedMxcUri,
};
use ruma::OwnedMxcUri;
use serde::Deserialize;
pub type OidcConfig = Vec<ProviderConfig>;
#[derive(Clone, Debug, Deserialize)]
pub struct Metadata {
// Must be unique, used to distinguish OPs
#[serde(rename = "idp_id")]
pub id: String,
#[serde(rename = "idp_name")]
pub name: Option<String>,
#[serde(rename = "idp_icon")]
pub icon: Option<OwnedMxcUri>,
}
impl Into<IdentityProvider> for Metadata {
fn into(self) -> IdentityProvider {
let brand = match IdentityProviderBrand::from(self.id.clone()) {
IdentityProviderBrand::_Custom(_) => None,
brand => Some(brand),
};
IdentityProvider {
id: self.id.clone(),
name: self.name.unwrap_or(self.id),
icon: self.icon,
brand,
}
}
}
#[derive(Clone, Debug, Deserialize)]
pub struct ProviderConfig {
pub id: String,
pub name: Option<String>,
pub icon: Option<OwnedMxcUri>,
// Information retrieved while creating the OpenID Application
pub client: ClientConfig,
// Information for displaying the OpenID Provider
#[serde(flatten)]
pub metadata: Metadata,
// Foo
// #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
pub issuer: url::Url,

View file

@ -24,7 +24,7 @@ use ruma::api::{
IncomingRequest,
};
use tokio::signal;
use tower::ServiceBuilder;
use tower::{ServiceBuilder, ServiceExt};
use tower_http::{
cors::{self, CorsLayer},
trace::TraceLayer,
@ -401,11 +401,22 @@ fn routes() -> Router {
"/_matrix/key/v2/server/:key_id",
get(server_server::get_server_keys_deprecated_route),
)
// .route(
// "/_matrix/client/v3/login/sso/redirect",
// get(client_server::get_sso_redirect),
// )
.ruma_route(client_server::get_sso_redirect_with_idp_id)
.route(
"/_matrix/client/v3/login/sso/redirect",
get(client_server::get_sso_redirect),
)
.route(
"/_matrix/client/v3/login/sso/redirect/",
get(client_server::get_sso_redirect),
)
.route(
"/_matrix/client/v3/login/sso/redirect/:idp_id",
get(client_server::get_sso_redirect_with_provider),
)
.route(
"/_conduit/v3/login/sso/callback",
get(client_server::get_sso_callback),
)
// .route("/sso_return", get(client_server::get_sso_return))
.ruma_route(server_server::get_public_rooms_route)
.ruma_route(server_server::get_public_rooms_filtered_route)

View file

@ -13,7 +13,7 @@ pub mod appservice;
pub mod globals;
pub mod key_backups;
pub mod media;
pub mod oidc;
pub mod sso;
pub mod pdu;
pub mod pusher;
pub mod rooms;
@ -35,7 +35,7 @@ pub struct Services {
pub key_backups: key_backups::Service,
pub media: media::Service,
pub sending: Arc<sending::Service>,
pub oidc: Arc<oidc::Service>,
pub sso: Arc<sso::Service>,
}
impl Services {
@ -116,8 +116,7 @@ impl Services {
key_backups: key_backups::Service { db },
media: media::Service { db },
sending: sending::Service::build(db, &config),
oidc: oidc::Service::build(&config).await,
sso: sso::Service::build(&config).await,
globals: globals::Service::load(db, config).await?,
})
}

View file

@ -7,14 +7,14 @@ use openidconnect::{
AuthUrl, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, PkceCodeChallenge, RedirectUrl,
Scope, TokenUrl, UserInfoUrl,
};
use ruma::api::client::session::get_login_types::v3::{IdentityProvider, IdentityProviderBrand};
use tokio::sync::OnceCell;
use crate::{
config::{DiscoveryConfig as Discovery, Metadata, ProviderConfig},
services, Config, Error,
};
use crate::{config::{DiscoveryConfig as Discovery, ProviderConfig}, services, Config, Error};
pub const COOKIE_STATE_EXPIRATION_SECS: i64 = 10 * 60;
pub const COOKIE_STATE_EXPIRATION_SECS: i64 = 60 * 60;
pub mod templates;
#[derive(Clone)]
pub struct Client(Arc<CoreClient>);
@ -47,15 +47,11 @@ impl Provider {
.map(ToOwned::to_owned)
.map(Scope::new);
let csrf = CsrfToken::new_random();
let nonce = Nonce::new_random();
let tmp = (csrf.clone(), nonce.clone());
let mut req = client
.authorize_url(
CoreAuthenticationFlow::Implicit(true),
move || csrf,
move || nonce,
|| CsrfToken::new_random_len(36),
|| Nonce::new_random_len(36),
)
.add_scopes(scopes);
@ -88,10 +84,12 @@ impl Provider {
.macaroon
.unwrap_or_else(MacaroonKey::generate_random);
let mut macaroon = Macaroon::create(None, &key, "oidc".into()).unwrap();
let expires = chrono::Utc::now() + chrono::TimeDelta::seconds(COOKIE_STATE_EXPIRATION_SECS);
let mut macaroon = Macaroon::create(None, &key, "sso".into()).unwrap();
let expires = (time::OffsetDateTime::now_utc()
+ time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS))
.to_string();
let idp_id = self.config.metadata.id.as_str();
let idp_id = self.config.id.as_str();
macaroon.add_first_party_caveat(format!("idp_id = {idp_id}").into());
macaroon.add_first_party_caveat(format!("state = {state}").into());
@ -107,6 +105,23 @@ impl Provider {
}
}
impl Into<IdentityProvider> for ProviderConfig {
fn into(self) -> IdentityProvider {
let brand = match IdentityProviderBrand::from(self.id.clone()) {
IdentityProviderBrand::_Custom(_) => None,
brand => Some(brand),
};
IdentityProvider {
id: self.id.clone(),
name: self.name.unwrap_or(self.id),
icon: self.icon,
brand,
}
}
}
pub struct Service {
pub inner: Vec<Provider>,
}
@ -114,39 +129,22 @@ pub struct Service {
impl Service {
pub async fn build(config: &Config) -> Arc<Self> {
Arc::new(Self {
inner: config.oidc.clone().into_iter().map(Provider::new).collect(),
inner: config.sso.clone().into_iter().map(Provider::new).collect(),
})
}
pub async fn get_provider(&self, idp_id: impl AsRef<str>) -> Result<Provider, ()> {
let Some(found) = self
.inner
pub fn get_provider(&self, idp_id: impl AsRef<str>) -> Option<Provider> {
self.inner
.iter()
.find(|p| p.config.metadata.id == idp_id.as_ref())
.map(Clone::clone)
else {
return Err(());
};
// let client = found.client
// .get_or_try_init(|| async { Client::new(config.clone()).await })
// .await
// .unwrap();
Ok(found)
.find(|p| p.config.id == idp_id.as_ref())
.map(ToOwned::to_owned)
}
pub fn get_all(&self) -> &[Provider] {
self.inner.as_slice()
}
// TODO
pub fn get_metadata(&self) -> Vec<Metadata> {
self.inner
.iter()
.map(|p| p.config.metadata.clone())
.collect()
}
pub fn validate_session(&self) {}
// pub async fn generate_auth_url<'s, S>(&self, idp_id: String, scopes: S) -> ()
// where
@ -212,7 +210,7 @@ impl Client {
)
.expect("server_name should be a valid URL");
base_url.set_path("_conduit/client/oidc/callback");
base_url.set_path("_conduit/client/sso/callback");
let redirect_url = RedirectUrl::from_url(base_url);
let client = match config.discovery {

View file

@ -1,32 +1,50 @@
use askama::Template;
use ruma::{OwnedUserId, api::client::search::search_events::v3::UserProfile, OwnedServerName};
use ruma::{
api::client::search::search_events::v3::UserProfile, OwnedMxcUri, OwnedServerName, OwnedUserId,
};
use crate::config::Metadata;
use super::Provider;
#[derive(Template)]
#[template(path = "auth_confirmation.html", escape = "none")]
pub struct AuthConfirmationTemplate {
pub struct AuthConfirmation {
description: String,
redirect_url: url::Url,
idp_name: String,
}
pub struct Metadata {
id: String,
name: Option<String>,
icon: Option<OwnedMxcUri>,
}
impl From<&Provider> for Metadata {
fn from(value: &Provider) -> Self {
Self {
id: value.config.id.clone(),
name: value.config.name.clone(),
icon: value.config.icon.clone(),
}
}
}
#[derive(Template)]
#[template(path = "auth_failure.html", escape = "none")]
pub struct AuthFailureTemplate {
pub struct AuthFailure {
server_name: OwnedServerName,
}
#[derive(Template)]
#[template(path = "auth_success.html", escape = "none")]
pub struct AuthSuccessTemplate {}
pub struct AuthSuccess {}
#[derive(Template)]
#[template(path = "deactivated.html", escape = "none")]
pub struct DeactivatedTemplate {}
pub struct Deactivated {}
#[derive(Template)]
#[template(path = "idp_picker.html", escape = "none")]
pub struct IdpPickerTemplate {
pub struct IdpPicker {
pub server_name: String,
pub metadata: Vec<Metadata>,
pub redirect_url: String,
@ -34,7 +52,7 @@ pub struct IdpPickerTemplate {
#[derive(Template)]
#[template(path = "registration.html", escape = "none")]
pub struct RegistrationTemplate {
pub struct Registration {
pub server_name: OwnedServerName,
pub idp: Metadata,
pub user: Attributes,
@ -49,7 +67,7 @@ pub struct Attributes {
#[derive(Template)]
#[template(path = "redirect_confirm.html", escape = "none")]
pub struct RedirectConfirmTemplate {
pub struct RedirectConfirm {
pub user_id: OwnedUserId,
pub user_profile: UserProfile,
pub display_url: url::Url,