Draft: SSO login (OAuth 2.0 + OpenID Connect) #1012
2 changed files with 267 additions and 192 deletions
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
service::sso::{templates, COOKIE_STATE_EXPIRATION_SECS},
|
||||
service::sso::{templates, Provider, COOKIE_STATE_EXPIRATION_SECS},
|
||||
services, Error, Ruma, RumaResponse,
|
||||
};
|
||||
use askama::Template;
|
||||
|
@ -7,10 +7,14 @@ use axum::{body::Full, response::IntoResponse};
|
|||
use axum_extra::extract::cookie::{Cookie, SameSite};
|
||||
use bytes::BytesMut;
|
||||
use http::StatusCode;
|
||||
use macaroon::ByteString;
|
||||
use openidconnect::{reqwest::{http_client, async_http_client}, AuthorizationCode, CsrfToken, TokenResponse};
|
||||
use ruma::api::{
|
||||
client::{error::ErrorKind, session},
|
||||
OutgoingResponse,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use time::macros::format_description;
|
||||
|
||||
/// # `GET /_matrix/client/v3/login/sso/redirect`
|
||||
///
|
||||
|
@ -23,7 +27,8 @@ pub async fn get_sso_redirect(
|
|||
.into_response();
|
||||
}
|
||||
|
||||
return get_sso_fallback_template(body.redirect_url.as_deref().unwrap_or_default()).into_response();
|
||||
return get_sso_fallback_template(body.redirect_url.as_deref().unwrap_or_default())
|
||||
.into_response();
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
|
||||
|
@ -39,14 +44,17 @@ pub async fn get_sso_redirect_with_provider(
|
|||
}
|
||||
|
||||
if body.idp_id.is_empty() {
|
||||
return get_sso_fallback_template(body.redirect_url.as_deref().unwrap_or_default()).into_response();
|
||||
return get_sso_fallback_template(body.redirect_url.as_deref().unwrap_or_default())
|
||||
.into_response();
|
||||
};
|
||||
|
||||
let Some(provider) = services().sso.get_provider(&body.idp_id) else {
|
||||
return Error::BadRequest(ErrorKind::NotFound, "Unknown identity provider").into_response();
|
||||
let location = Some(body.idp_id.clone());
|
||||
|
||||
let (url, nonce, cookie) = match services().sso.find_one(&body.idp_id).map(|provider| provider.handle_redirect(body.redirect_url.as_deref().unwrap_or_default())) {
|
||||
Ok(fut)=> fut.await,
|
||||
Err(e)=> return e.into_response(),
|
||||
};
|
||||
|
||||
let (location, cookie) = provider.handle_redirect(body.redirect_url.as_deref().unwrap_or_default()).await;
|
||||
|
||||
let cookie = Cookie::build("openid-state", cookie)
|
||||
.path("/_conduit/client/sso")
|
||||
|
@ -59,7 +67,7 @@ pub async fn get_sso_redirect_with_provider(
|
|||
.to_string();
|
||||
|
||||
let mut res = session::sso_login_with_provider::v3::Response {
|
||||
location: Some(location.to_string()),
|
||||
location,
|
||||
cookie: Some(cookie),
|
||||
}
|
||||
.try_into_http_response::<BytesMut>()
|
||||
|
@ -95,8 +103,52 @@ fn get_sso_fallback_template(redirect_url: &str) -> axum::response::Response {
|
|||
.expect("woops")
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Callback {
|
||||
pub code: AuthorizationCode,
|
||||
pub state: CsrfToken,
|
||||
}
|
||||
|
||||
/// # `GET /_conduit/client/oidc/callback`
|
||||
///
|
||||
/// Verify the response received from the identity provider.
|
||||
/// If everything is fine redirect
|
||||
pub async fn get_sso_callback() {}
|
||||
pub async fn get_sso_callback(
|
||||
cookie: axum::extract::TypedHeader<axum::headers::Cookie>,
|
||||
axum::extract::Query(callback): axum::extract::Query<Callback>,
|
||||
) -> axum::response::Response {
|
||||
// TODO
|
||||
|
||||
let Callback { code, state } = callback;
|
||||
|
||||
let Some(cookie) = cookie.get("openid-state") else {
|
||||
return Error::BadRequest(
|
||||
ErrorKind::MissingToken,
|
||||
"Could not retrieve SSO macaroon from cookie",
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
let provider = match Provider::verify_macaroon(cookie.as_bytes(), state)
|
||||
.and_then(|macaroon| services().sso.find_one(macaroon.identifier().into()))
|
||||
{
|
||||
Ok(provider) => provider,
|
||||
Err(error) => return error.into_response(),
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
let cookie = Cookie::build("openid-state", "")
|
||||
.path("/_conduit/client/sso")
|
||||
.finish()
|
||||
.to_string();
|
||||
|
||||
let user_info = provider.handle_callback(code, nonce);
|
||||
|
||||
// if let Some(verifier) = pkce {
|
||||
// macaroon.add_first_party_caveat(format!("verifier = {}", verifier).into());
|
||||
// }
|
||||
|
||||
(TypedHeader(ContentType::text_utf8()), "Hello, World!").into_response()
|
||||
}
|
||||
|
|
|
@ -1,127 +1,26 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use macaroon::{Macaroon, MacaroonKey};
|
||||
use futures_util::future::{self};
|
||||
use macaroon::{Macaroon, Verifier};
|
||||
use openidconnect::{
|
||||
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
|
||||
core::{CoreAuthenticationFlow, CoreClient, CoreGenderClaim, CoreProviderMetadata},
|
||||
reqwest::async_http_client,
|
||||
AuthUrl, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, PkceCodeChallenge, RedirectUrl,
|
||||
Scope, TokenUrl, UserInfoUrl,
|
||||
AccessTokenHash, AdditionalClaims, AuthUrl, AuthorizationCode, ClientId, ClientSecret,
|
||||
CsrfToken, IssuerUrl, Nonce, NonceVerifier, OAuth2TokenResponse, PkceCodeChallenge,
|
||||
RedirectUrl, Scope, SubjectIdentifier, TokenResponse, TokenUrl, UserInfoClaims, UserInfoUrl,
|
||||
};
|
||||
use ruma::api::client::session::get_login_types::v3::{IdentityProvider, IdentityProviderBrand};
|
||||
use tokio::sync::OnceCell;
|
||||
use ruma::api::client::{error::ErrorKind, session::get_login_types::v3::IdentityProvider};
|
||||
use time::macros::format_description;
|
||||
|
||||
use crate::{config::{DiscoveryConfig as Discovery, ProviderConfig}, services, Config, Error};
|
||||
use crate::{
|
||||
config::{ClientConfig, DiscoveryConfig as Discovery, ProviderConfig},
|
||||
services, Config, Error,
|
||||
};
|
||||
|
||||
pub const COOKIE_STATE_EXPIRATION_SECS: i64 = 60 * 60;
|
||||
|
||||
pub mod templates;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client(Arc<CoreClient>);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Provider {
|
||||
pub config: ProviderConfig,
|
||||
pub client: OnceCell<Client>,
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
pub fn new(config: ProviderConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
client: OnceCell::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_redirect(&self, redirect_url: &str) -> (url::Url, String) {
|
||||
let client = self
|
||||
.client
|
||||
.get_or_try_init(|| async { Client::new(self.config.clone()).await })
|
||||
.await
|
||||
.map(|c| c.0.clone())
|
||||
.unwrap();
|
||||
let scopes = self
|
||||
.config
|
||||
.scopes
|
||||
.iter()
|
||||
.map(ToOwned::to_owned)
|
||||
.map(Scope::new);
|
||||
|
||||
let mut req = client
|
||||
.authorize_url(
|
||||
CoreAuthenticationFlow::Implicit(true),
|
||||
|| CsrfToken::new_random_len(36),
|
||||
|| Nonce::new_random_len(36),
|
||||
)
|
||||
.add_scopes(scopes);
|
||||
|
||||
let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
if let Some(true) = self.config.pkce {
|
||||
req = req.set_pkce_challenge(challenge);
|
||||
}
|
||||
|
||||
let (url, csrf, nonce) = req.url();
|
||||
|
||||
let cookie = self.generate_macaroon(
|
||||
csrf.secret(),
|
||||
nonce.secret(),
|
||||
redirect_url,
|
||||
self.config.pkce.map(|_| verifier.secret().as_str()),
|
||||
);
|
||||
|
||||
(url, cookie)
|
||||
}
|
||||
|
||||
pub fn generate_macaroon(
|
||||
&self,
|
||||
state: &str,
|
||||
nonce: &str,
|
||||
redirect_url: &str,
|
||||
pkce: Option<&str>,
|
||||
) -> String {
|
||||
let key = services()
|
||||
.globals
|
||||
.macaroon
|
||||
.unwrap_or_else(MacaroonKey::generate_random);
|
||||
|
||||
let mut macaroon = Macaroon::create(None, &key, "sso".into()).unwrap();
|
||||
let expires = (time::OffsetDateTime::now_utc()
|
||||
+ time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS))
|
||||
.to_string();
|
||||
|
||||
let idp_id = self.config.id.as_str();
|
||||
|
||||
macaroon.add_first_party_caveat(format!("idp_id = {idp_id}").into());
|
||||
macaroon.add_first_party_caveat(format!("state = {state}").into());
|
||||
macaroon.add_first_party_caveat(format!("nonce = {nonce}").into());
|
||||
macaroon.add_first_party_caveat(format!("redirect_url = {redirect_url}").into());
|
||||
macaroon.add_first_party_caveat(format!("time < {expires}").into());
|
||||
|
||||
if let Some(verifier) = pkce {
|
||||
macaroon.add_first_party_caveat(format!("verifier = {}", verifier).into());
|
||||
}
|
||||
|
||||
macaroon.serialize(macaroon::Format::V2).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<IdentityProvider> for ProviderConfig {
|
||||
fn into(self) -> IdentityProvider {
|
||||
let brand = match IdentityProviderBrand::from(self.id.clone()) {
|
||||
IdentityProviderBrand::_Custom(_) => None,
|
||||
brand => Some(brand),
|
||||
};
|
||||
|
||||
IdentityProvider {
|
||||
id: self.id.clone(),
|
||||
name: self.name.unwrap_or(self.id),
|
||||
icon: self.icon,
|
||||
brand,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub struct Service {
|
||||
pub inner: Vec<Provider>,
|
||||
}
|
||||
|
@ -129,78 +28,59 @@ pub struct Service {
|
|||
impl Service {
|
||||
pub async fn build(config: &Config) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
inner: config.sso.clone().into_iter().map(Provider::new).collect(),
|
||||
inner: future::join_all(config.sso.clone().into_iter().map(Provider::new)).await,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_provider(&self, idp_id: impl AsRef<str>) -> Option<Provider> {
|
||||
self.inner
|
||||
.iter()
|
||||
.find(|p| p.config.id == idp_id.as_ref())
|
||||
.map(ToOwned::to_owned)
|
||||
pub fn find_one(&self, idp_id: impl AsRef<str>) -> Result<Provider, Error> {
|
||||
match self.inner.iter().find(|p| p.inner.id == idp_id.as_ref()) {
|
||||
Some(provider) => Ok(provider.to_owned()),
|
||||
None => Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"unknown identity provider",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_all(&self) -> &[Provider] {
|
||||
self.inner.as_slice()
|
||||
}
|
||||
|
||||
pub fn validate_session(&self) {}
|
||||
|
||||
// pub async fn generate_auth_url<'s, S>(&self, idp_id: String, scopes: S) -> ()
|
||||
// where
|
||||
// S: Iterator<Item = &'s str>,
|
||||
// {
|
||||
|
||||
// // TODO: PKCE challenge
|
||||
// if false {
|
||||
// let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
// // return auth_url.set_pkce_challenge(challenge).url()
|
||||
// }
|
||||
|
||||
// // auth_url.url()
|
||||
// }
|
||||
|
||||
// pub async fn exchange_auth_code(&self, auth_code: AuthorizationCode) {
|
||||
// // -> Result<StandardTokenResponse, Error> {
|
||||
// let provider = self.get_client("something").await.unwrap();
|
||||
|
||||
// let req = provider.exchange_code(auth_code);
|
||||
// if false {
|
||||
// // return req.set_pkce_verifier(pkce_verifier).request(http_client)?
|
||||
// }
|
||||
|
||||
// let resp = req.request(http_client).unwrap();
|
||||
|
||||
// let id_token = resp.id_token().unwrap();
|
||||
|
||||
// let access_token = resp.access_token();
|
||||
|
||||
// let claims = id_token
|
||||
// .claims(&provider.id_token_verifier(), &Nonce::new("".into()))
|
||||
// .unwrap();
|
||||
|
||||
// if let Some(hash) = claims.access_token_hash() {
|
||||
// &AccessTokenHash::from_token(access_token, &id_token.signing_alg().unwrap()).unwrap()
|
||||
// == hash;
|
||||
// }
|
||||
|
||||
// // need `UserInfo` endpoint
|
||||
// // if let Some(subject) = config.subject_claim {
|
||||
// if false {
|
||||
// let req = provider
|
||||
// .user_info(
|
||||
// access_token.clone(),
|
||||
// Some(SubjectIdentifier::new("id".to_owned())),
|
||||
// )
|
||||
// .unwrap();
|
||||
// let ok: CoreUserInfoClaims = req.request(http_client).unwrap();
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub async fn new(config: ProviderConfig) -> Result<Self, Error> {
|
||||
#[derive(Clone)]
|
||||
pub struct Provider {
|
||||
pub inner: IdentityProvider,
|
||||
pub client: Arc<CoreClient>,
|
||||
pub scopes: Vec<String>,
|
||||
pub pkce: Option<bool>,
|
||||
pub subject_claim: Option<String>,
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
pub async fn new(config: ProviderConfig) -> Self {
|
||||
let inner = IdentityProvider {
|
||||
id: config.id.clone(),
|
||||
name: config.name.unwrap_or(config.id),
|
||||
icon: config.icon,
|
||||
brand: None,
|
||||
};
|
||||
|
||||
Self {
|
||||
inner,
|
||||
client: Provider::create_client(config.discovery, config.issuer, config.client)
|
||||
.await
|
||||
.unwrap(),
|
||||
scopes: config.scopes,
|
||||
pkce: config.pkce,
|
||||
subject_claim: config.subject_claim,
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_client(
|
||||
discovery: Discovery,
|
||||
issuer: url::Url,
|
||||
config: ClientConfig,
|
||||
) -> Result<Arc<CoreClient>, Error> {
|
||||
let mut base_url = url::Url::try_from(
|
||||
services()
|
||||
.globals
|
||||
|
@ -210,12 +90,12 @@ impl Client {
|
|||
)
|
||||
.expect("server_name should be a valid URL");
|
||||
|
||||
base_url.set_path("_conduit/client/sso/callback");
|
||||
base_url.set_path("_conduit/config/sso/callback");
|
||||
let redirect_url = RedirectUrl::from_url(base_url);
|
||||
|
||||
let client = match config.discovery {
|
||||
let config = match discovery {
|
||||
Discovery::Automatic => {
|
||||
let url = config.issuer.to_string();
|
||||
let url = issuer.to_string();
|
||||
let url = url.strip_suffix("/").unwrap();
|
||||
|
||||
let discovery = CoreProviderMetadata::discover_async(
|
||||
|
@ -229,14 +109,14 @@ impl Client {
|
|||
|
||||
CoreClient::from_provider_metadata(
|
||||
discovery,
|
||||
ClientId::new(config.client.id),
|
||||
config.client.secret.map(ClientSecret::new),
|
||||
ClientId::new(config.id),
|
||||
config.secret.map(ClientSecret::new),
|
||||
)
|
||||
}
|
||||
Discovery::Manual(endpoints) => CoreClient::new(
|
||||
ClientId::new(config.client.id),
|
||||
config.client.secret.map(ClientSecret::new),
|
||||
IssuerUrl::from_url(config.issuer),
|
||||
ClientId::new(config.id),
|
||||
config.secret.map(ClientSecret::new),
|
||||
IssuerUrl::from_url(issuer),
|
||||
AuthUrl::from_url(endpoints.auth),
|
||||
endpoints.token.map(TokenUrl::from_url),
|
||||
endpoints.userinfo.map(UserInfoUrl::from_url),
|
||||
|
@ -245,6 +125,149 @@ impl Client {
|
|||
.set_redirect_uri(redirect_url),
|
||||
};
|
||||
|
||||
Ok(Self(Arc::new(client)))
|
||||
Ok(Arc::new(config))
|
||||
}
|
||||
|
||||
pub async fn handle_redirect(&self, redirect_url: &str) -> (url::Url, String, String) {
|
||||
let client = self.client.clone();
|
||||
let scopes = self.scopes.iter().map(ToOwned::to_owned).map(Scope::new);
|
||||
|
||||
let mut req = client
|
||||
.authorize_url(
|
||||
CoreAuthenticationFlow::Implicit(true),
|
||||
|| CsrfToken::new_random_len(36),
|
||||
|| Nonce::new_random_len(36),
|
||||
)
|
||||
.add_scopes(scopes);
|
||||
|
||||
let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
if let Some(true) = self.pkce {
|
||||
req = req.set_pkce_challenge(challenge);
|
||||
}
|
||||
|
||||
let (url, csrf, nonce) = req.url();
|
||||
|
||||
let cookie = self.generate_macaroon(
|
||||
self.inner.id.as_str(),
|
||||
csrf.secret(),
|
||||
nonce.secret(),
|
||||
redirect_url,
|
||||
self.pkce.map(|_| verifier.secret().as_str()),
|
||||
);
|
||||
|
||||
(url, nonce.secret().to_owned(), cookie)
|
||||
}
|
||||
|
||||
pub fn generate_macaroon(
|
||||
&self,
|
||||
idp_id: &str,
|
||||
state: &str,
|
||||
nonce: &str,
|
||||
redirect_url: &str,
|
||||
pkce: Option<&str>,
|
||||
) -> String {
|
||||
let key = services().globals.macaroon.unwrap();
|
||||
|
||||
let mut macaroon = Macaroon::create(None, &key, idp_id.into()).unwrap();
|
||||
let expires = (time::OffsetDateTime::now_utc()
|
||||
+ time::Duration::seconds(COOKIE_STATE_EXPIRATION_SECS))
|
||||
.to_string();
|
||||
|
||||
let idp_id = self.inner.id.as_str();
|
||||
|
||||
for caveat in [
|
||||
format!("idp_id = {idp_id}"),
|
||||
format!("state = {state}"),
|
||||
format!("nonce = {nonce}"),
|
||||
format!("redirect_url = {redirect_url}"),
|
||||
format!("time < {expires}"),
|
||||
] {
|
||||
macaroon.add_first_party_caveat(caveat.into());
|
||||
}
|
||||
|
||||
if let Some(verifier) = pkce {
|
||||
macaroon.add_first_party_caveat(format!("verifier = {}", verifier).into());
|
||||
}
|
||||
|
||||
macaroon.serialize(macaroon::Format::V2).unwrap()
|
||||
}
|
||||
|
||||
pub fn verify_macaroon(cookie: &[u8], state: CsrfToken) -> Result<Macaroon, Error> {
|
||||
let mut verifier = Verifier::default();
|
||||
|
||||
let macaroon = Macaroon::deserialize(cookie).map_err(|e| {
|
||||
Error::BadRequest(ErrorKind::BadJson, "Could not deserialize SSO macaroon")
|
||||
})?;
|
||||
|
||||
verifier.satisfy_exact(format!("state = {}", state.secret()).into());
|
||||
|
||||
// let verification = |s: &ByteString, id: &str| {
|
||||
// s.0.starts_with(format!("{id} =").as_bytes()); // TODO
|
||||
// };
|
||||
|
||||
verifier.satisfy_general(|s| s.0.starts_with(b"idp_id ="));
|
||||
verifier.satisfy_general(|s| s.0.starts_with(b"nonce ="));
|
||||
verifier.satisfy_general(|s| s.0.starts_with(b"redirect_url ="));
|
||||
|
||||
verifier.satisfy_general(|s| {
|
||||
let format_desc = format_description!(
|
||||
"[year]-[month]-[day] [hour]:[minute]:[second] [offset_hour \
|
||||
sign:mandatory]:[offset_minute]:[offset_second]"
|
||||
);
|
||||
|
||||
let now = time::OffsetDateTime::now_utc();
|
||||
|
||||
time::OffsetDateTime::parse(std::str::from_utf8(&s.0).unwrap(), format_desc)
|
||||
.map(|expires| now < expires)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
let key = services().globals.macaroon.unwrap();
|
||||
|
||||
verifier
|
||||
.verify(&macaroon, &key, Default::default())
|
||||
.map_err(|e| {
|
||||
Error::BadRequest(ErrorKind::Unauthorized, "Macaroon verification failed")
|
||||
})?;
|
||||
|
||||
Ok(macaroon)
|
||||
}
|
||||
|
||||
pub async fn handle_callback<Claims: AdditionalClaims>(
|
||||
&self,
|
||||
code: AuthorizationCode,
|
||||
nonce: Nonce,
|
||||
) -> Result<UserInfoClaims<Claims, CoreGenderClaim>, Error> {
|
||||
let resp = self
|
||||
.client
|
||||
.exchange_code(code)
|
||||
.request_async(async_http_client)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let id_token = resp.id_token().unwrap();
|
||||
let claims = id_token
|
||||
.claims(&self.client.id_token_verifier(), &nonce)
|
||||
.unwrap();
|
||||
|
||||
if let Some(expected) = claims.access_token_hash() {
|
||||
let found =
|
||||
AccessTokenHash::from_token(resp.access_token(), &id_token.signing_alg().unwrap())
|
||||
.unwrap();
|
||||
|
||||
if &found != expected {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
|
||||
let Ok(req) = self.client.user_info(
|
||||
resp.access_token().to_owned(),
|
||||
self.subject_claim.clone().map(SubjectIdentifier::new),
|
||||
) else {
|
||||
resp.extra_fields();
|
||||
panic!()
|
||||
};
|
||||
|
||||
Ok(req.request_async(async_http_client).await.unwrap())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue