Skip to content

Commit

Permalink
Add support for wildcard '*' in cors domains (#563)
Browse files Browse the repository at this point in the history
* added the wildcard crate.
Added the functionality to specify wildcard '*'  in cors domains

* added test

* Corrected bug

---------

Co-authored-by: Samuel Batissou <samuel.batissou@thinkeo.io>
  • Loading branch information
Sagebati and Sagebati authored Nov 19, 2023
1 parent 1643d7a commit 80d31ca
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
1 change: 1 addition & 0 deletions poem/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ headers = "0.3.7"
thiserror.workspace = true
rfc7239 = "0.1.0"
mime.workspace = true
wildmatch = "2"

# Non-feature optional dependencies
multer = { version = "2.1.0", features = ["tokio"], optional = true }
Expand Down
59 changes: 58 additions & 1 deletion poem/src/middleware/cors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{collections::HashSet, str::FromStr, sync::Arc};
use headers::{
AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMapExt,
};
use wildmatch::WildMatch;

use crate::{
endpoint::Endpoint,
Expand Down Expand Up @@ -39,6 +40,7 @@ use crate::{
pub struct Cors {
allow_credentials: bool,
allow_origins: HashSet<HeaderValue>,
allow_origins_wildcard: Vec<WildMatch>,
allow_origins_fn: Option<Arc<dyn Fn(&str) -> bool + Send + Sync>>,
allow_headers: HashSet<HeaderName>,
allow_methods: HashSet<Method>,
Expand Down Expand Up @@ -135,6 +137,14 @@ impl Cors {
self
}

/// Add an allowed origin that supports '*' wildcard.
/// Example: `rust cors.allow_origin_regex("https://*.domain.url")`
fn allow_origin_regex(mut self, origin: impl AsRef<str>) -> Self {
self.allow_origins_wildcard
.push(WildMatch::new(origin.as_ref()));
self
}

/// Add many allow origins.
#[must_use]
pub fn allow_origins<I, T>(self, origins: I) -> Self
Expand Down Expand Up @@ -203,6 +213,7 @@ impl<E: Endpoint> Middleware<E> for Cors {
inner: ep,
allow_credentials: self.allow_credentials,
allow_origins: self.allow_origins.clone(),
allow_origins_wildcard: self.allow_origins_wildcard.clone(),
allow_origins_fn: self.allow_origins_fn.clone(),
allow_headers: self.allow_headers.clone(),
allow_methods: self.allow_methods.clone(),
Expand All @@ -221,6 +232,7 @@ pub struct CorsEndpoint<E> {
inner: E,
allow_credentials: bool,
allow_origins: HashSet<HeaderValue>,
allow_origins_wildcard: Vec<WildMatch>,
allow_origins_fn: Option<Arc<dyn Fn(&str) -> bool + Send + Sync>>,
allow_headers: HashSet<HeaderName>,
allow_methods: HashSet<Method>,
Expand All @@ -237,6 +249,14 @@ impl<E: Endpoint> CorsEndpoint<E> {
return (true, false);
}

if self
.allow_origins_wildcard
.iter()
.any(|m| m.matches(origin.to_str().unwrap()))
{
return (true, true);
}

if let Some(allow_origins_fn) = &self.allow_origins_fn {
if let Ok(origin) = origin.to_str() {
if allow_origins_fn(origin) {
Expand All @@ -246,7 +266,9 @@ impl<E: Endpoint> CorsEndpoint<E> {
}

(
self.allow_origins.is_empty() && self.allow_origins_fn.is_none(),
self.allow_origins.is_empty()
&& self.allow_origins_fn.is_none()
&& self.allow_origins_wildcard.is_empty(),
true,
)
}
Expand Down Expand Up @@ -541,6 +563,41 @@ mod tests {
resp.assert_status(StatusCode::FORBIDDEN);
}

#[tokio::test]
async fn allow_origins_fn_4() {
let ep =
make_sync(|_| "hello").with(Cors::new().allow_origin_regex("https://*example.com"));
let cli = TestClient::new(ep);

let resp = cli
.get("/")
.header(header::ORIGIN, "https://example.mx")
.send()
.await;
resp.assert_status(StatusCode::FORBIDDEN);

let resp = cli
.get("/")
.header(header::ORIGIN, "https://test.example.com")
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
"https://test.example.com",
);
resp.assert_header_is_not_exist(header::VARY);

let resp = cli
.get("/")
.header(header::ORIGIN, "https://example.com")
.send()
.await;
resp.assert_status_is_ok();
resp.assert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "https://example.com");
resp.assert_header(header::VARY, "Origin");
}

#[tokio::test]
async fn default_cors_middleware() {
let ep = make_sync(|_| "hello").with(Cors::new());
Expand Down

0 comments on commit 80d31ca

Please sign in to comment.