Skip to content

Commit

Permalink
Allow setting an explicit upstream account name
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose committed Nov 29, 2024
1 parent 2c01b43 commit 5f47039
Show file tree
Hide file tree
Showing 21 changed files with 279 additions and 146 deletions.
3 changes: 2 additions & 1 deletion crates/cli/src/commands/manage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,9 +955,10 @@ impl UserCreationRequest<'_> {
}

for (provider, subject) in upstream_provider_mappings {
// Note that we don't pass a human_account_name here, as we don't ask for it
let link = repo
.upstream_oauth_link()
.add(rng, clock, provider, subject)
.add(rng, clock, provider, subject, None)
.await?;

repo.upstream_oauth_link()
Expand Down
3 changes: 3 additions & 0 deletions crates/cli/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ fn map_claims_imports(
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Import
}
},
account_name: mas_data_model::UpstreamOAuthProviderSubjectPreference {
template: config.account_name.template.clone(),
},
}
}

Expand Down
24 changes: 24 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,23 @@ impl EmailImportPreference {
}
}

/// What should be done for the account name attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct AccountNameImportPreference {
/// The Jinja2 template to use for the account name. This name is only used
/// for display purposes.
///
/// If not provided, it will be ignored.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub template: Option<String>,
}

impl AccountNameImportPreference {
const fn is_default(&self) -> bool {
self.template.is_none()
}
}

/// How claims should be imported
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ClaimsImports {
Expand All @@ -307,6 +324,13 @@ pub struct ClaimsImports {
/// `email_verified` claims
#[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
pub email: EmailImportPreference,

/// Set a human-readable name for the upstream account for display purposes
#[serde(
default,
skip_serializing_if = "AccountNameImportPreference::is_default"
)]
pub account_name: AccountNameImportPreference,
}

impl ClaimsImports {
Expand Down
1 change: 1 addition & 0 deletions crates/data-model/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ pub struct UpstreamOAuthLink {
pub provider_id: Ulid,
pub user_id: Option<Ulid>,
pub subject: String,
pub human_account_name: Option<String>,
pub created_at: DateTime<Utc>,
}
4 changes: 4 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,14 @@ pub struct ClaimsImports {
#[serde(default)]
pub email: ImportPreference,

#[serde(default)]
pub account_name: SubjectPreference,

#[serde(default)]
pub verify_email: SetEmailVerification,
}

// XXX: this should have another name
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct SubjectPreference {
#[serde(default)]
Expand Down
22 changes: 20 additions & 2 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ pub(crate) async fn handler(
.as_deref()
.unwrap_or("{{ user.sub }}");
let subject = env
.render_str(template, context)
.render_str(template, context.clone())
.map_err(RouteError::ExtractSubject)?;

if subject.is_empty() {
Expand All @@ -375,8 +375,26 @@ pub(crate) async fn handler(
let link = if let Some(link) = maybe_link {
link
} else {
// Try to render the human account name if we have one,
// but just log if it fails
let human_account_name = provider
.claims_imports
.account_name
.template
.as_deref()
.and_then(|template| match env.render_str(template, context) {
Ok(name) => Some(name),
Err(e) => {
tracing::warn!(
error = &e as &dyn std::error::Error,
"Failed to render account name"
);
None
}
});

repo.upstream_oauth_link()
.add(&mut rng, &clock, &provider, subject)
.add(&mut rng, &clock, &provider, subject, human_account_name)
.await?
};

Expand Down
12 changes: 9 additions & 3 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ pub(crate) async fn get(
.await?
.ok_or(RouteError::ProviderNotFound)?;

let ctx = UpstreamRegister::default();
let ctx = UpstreamRegister::new(link.clone(), provider.clone());

let env = environment();

Expand Down Expand Up @@ -596,7 +596,7 @@ pub(crate) async fn post(
.map_or(false, |v| v == "true");

// Create a template context in case we need to re-render because of an error
let ctx = UpstreamRegister::default();
let ctx = UpstreamRegister::new(link.clone(), provider.clone());

let display_name = if provider
.claims_imports
Expand Down Expand Up @@ -954,7 +954,13 @@ mod tests {

let link = repo
.upstream_oauth_link()
.add(&mut rng, &state.clock, &provider, "subject".to_owned())
.add(
&mut rng,
&state.clock,
&provider,
"subject".to_owned(),
None,
)
.await
.unwrap();

Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions crates/storage-pg/src/iden.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,6 @@ pub enum UpstreamOAuthLinks {
UpstreamOAuthProviderId,
UserId,
Subject,
HumanAccountName,
CreatedAt,
}
18 changes: 17 additions & 1 deletion crates/storage-pg/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct LinkLookup {
upstream_oauth_provider_id: Uuid,
user_id: Option<Uuid>,
subject: String,
human_account_name: Option<String>,
created_at: DateTime<Utc>,
}

Expand All @@ -57,6 +58,7 @@ impl From<LinkLookup> for UpstreamOAuthLink {
provider_id: Ulid::from(value.upstream_oauth_provider_id),
user_id: value.user_id.map(Ulid::from),
subject: value.subject,
human_account_name: value.human_account_name,
created_at: value.created_at,
}
}
Expand Down Expand Up @@ -124,6 +126,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
upstream_oauth_provider_id,
user_id,
subject,
human_account_name,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_link_id = $1
Expand Down Expand Up @@ -163,6 +166,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
upstream_oauth_provider_id,
user_id,
subject,
human_account_name,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_provider_id = $1
Expand All @@ -186,6 +190,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
db.query.text,
upstream_oauth_link.id,
upstream_oauth_link.subject = subject,
upstream_oauth_link.human_account_name = human_account_name,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
Expand All @@ -198,6 +203,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
clock: &dyn Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String,
human_account_name: Option<String>,
) -> Result<UpstreamOAuthLink, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
Expand All @@ -210,12 +216,14 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
upstream_oauth_provider_id,
user_id,
subject,
human_account_name,
created_at
) VALUES ($1, $2, NULL, $3, $4)
) VALUES ($1, $2, NULL, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&subject,
human_account_name.as_deref(),
created_at,
)
.traced()
Expand All @@ -227,6 +235,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
provider_id: upstream_oauth_provider.id,
user_id: None,
subject,
human_account_name,
created_at,
})
}
Expand Down Expand Up @@ -300,6 +309,13 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
LinkLookupIden::Subject,
)
.expr_as(
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::HumanAccountName,
)),
LinkLookupIden::HumanAccountName,
)
.expr_as(
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
LinkLookupIden::CreatedAt,
Expand Down
2 changes: 1 addition & 1 deletion crates/storage-pg/src/upstream_oauth2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ mod tests {
// Create a link
let link = repo
.upstream_oauth_link()
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
.add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
.await
.unwrap();

Expand Down
3 changes: 3 additions & 0 deletions crates/storage/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
/// * `upsream_oauth_provider`: The upstream OAuth provider for which to
/// create the link
/// * `subject`: The subject of the upstream OAuth link to create
/// * `human_account_name`: A human-readable name for the upstream account
///
/// # Errors
///
Expand All @@ -138,6 +139,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
clock: &dyn Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String,
human_account_name: Option<String>,
) -> Result<UpstreamOAuthLink, Self::Error>;

/// Associate an upstream OAuth link to a user
Expand Down Expand Up @@ -201,6 +203,7 @@ repository_impl!(UpstreamOAuthLinkRepository:
clock: &dyn Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String,
human_account_name: Option<String>,
) -> Result<UpstreamOAuthLink, Self::Error>;

async fn associate_to_user(
Expand Down
Loading

0 comments on commit 5f47039

Please sign in to comment.