Skip to content

Commit

Permalink
Client authentication on token endpoint for all grant types
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjwwalker committed Nov 22, 2020
1 parent b0aed71 commit 82cf2a4
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 355 deletions.
97 changes: 80 additions & 17 deletions app/controllers/actions/BasicAuthAction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,38 +37,101 @@ trait BasicAuthAction {

implicit val ec: ExC

private type ClientAction = Request[AnyContent] => RegisteredApplication => Future[Result]
private type ClientAction = Request[AnyContent] => RegisteredApplication => Option[String] => Future[Result]

private def noAuthHeader = {
logger.warn(s"[clientAuthentication] - No auth header found in the request")
Future.successful(Unauthorized(StandardErrors.INVALID_REQUEST))
def clientAuthenticationOptionalPkce(f: ClientAction)(implicit ec: ExC): Action[AnyContent] = Action.async { implicit req =>
val clientCreds = getClientCredFromHeader.fold(getClientCredFromBody)(Some(_))
val pkceVerifier = getPkceCodeVerifier
clientCreds -> pkceVerifier match {
case (None, Some((id, verifier))) => clientService.getRegisteredAppById(id) flatMap {
case Some(app) =>
val decodedApp = RegisteredApplication.decode(app)
logger.info(s"[clientAuthentication] - Matched client with clientId $id")
f(req)(decodedApp)(Some(verifier))
case None =>
logger.warn(s"[clientAuthentication] - No client has been found matching clientId $id")
Future.successful(Unauthorized(StandardErrors.INVALID_CLIENT))
}
case (Some((id, sec)), None) => clientService.getRegisteredAppByIdAndSecret(id, sec) flatMap {
case Some(app) =>
val decodedApp = RegisteredApplication.decode(app)
logger.info(s"[clientAuthentication] - Matched client with clientId $id")
f(req)(decodedApp)(None)
case None =>
logger.warn(s"[clientAuthentication] - No client has been found matching clientId $id")
Future.successful(Unauthorized(StandardErrors.INVALID_CLIENT))
}
}
}

def clientAuthentication[T](f: ClientAction)(implicit ec: ExC): Action[AnyContent] = Action.async { implicit req =>
req.headers.get("Authorization").fold(noAuthHeader) { auth =>
def clientAuthentication(f: ClientAction)(implicit ec: ExC): Action[AnyContent] = Action.async { implicit req =>
val clientCreds = getClientCredFromHeader.fold(getClientCredFromBody)(x => Some(x))
clientCreds match {
case Some((id, sec)) => clientService.getRegisteredAppByIdAndSecret(id, sec) flatMap {
case Some(app) =>
val decodedApp = RegisteredApplication.decode(app)
logger.info(s"[clientAuthentication] - Matched client with clientId $id")
f(req)(decodedApp)(None)
case None =>
logger.warn(s"[clientAuthentication] - No client has been found matching clientId $id")
Future.successful(Unauthorized(StandardErrors.INVALID_CLIENT))
}
case None =>
Future.successful(Unauthorized(StandardErrors.INVALID_CLIENT))
}
}

private def getClientCredFromBody(implicit req: Request[AnyContent]): Option[(String, String)] = {
val body = req.body.asFormUrlEncoded.getOrElse(Map())
val clientId = body.getOrElse("client_id", Seq()).headOption
val clientSec = body.getOrElse("client_secret", Seq()).headOption
clientId -> clientSec match {
case (Some(id), Some(sec)) =>
logger.info(s"[clientAuthentication] - The client id and secret were found in the request body")
Some((id, sec))
case (_, _) =>
logger.warn(s"[clientAuthentication] - Either the client Id or secret was not included in the request body")
None
}
}

private def getPkceCodeVerifier(implicit req: Request[AnyContent]): Option[(String, String)] = {
val body = req.body.asFormUrlEncoded.getOrElse(Map())
val clientId = body.getOrElse("client_id", Seq()).headOption
val codeVerifier = body.getOrElse("code_verifier", Seq()).headOption
clientId -> codeVerifier match {
case (Some(id), Some(sec)) =>
logger.info(s"[clientAuthentication] - The client id and code verifier were found in the request body")
Some((id, sec))
case (_, _) =>
logger.warn(s"[clientAuthentication] - Either the client Id or code verifier was not included in the request body")
None
}
}

private def getClientCredFromHeader(implicit req: Request[_]): Option[(String, String)] = {
req.headers.get("Authorization").fold[Option[(String, String)]](noAuthHeader) { auth =>
val splitHeader = auth.split(" ")
if(splitHeader.head == "Basic") {
Try(Base64.getDecoder.decode(splitHeader.last)).fold(
Try(Base64.getDecoder.decode(splitHeader.last)).fold[Option[(String, String)]](
err => {
logger.warn(s"[clientAuthentication] - Basic auth header was found, but payload was not Base64", err)
Future.successful(Unauthorized(StandardErrors.INVALID_REQUEST))
None
},
basicAuthHeader => {
val Array(clientId, clientSecret) = new String(basicAuthHeader, StandardCharsets.UTF_8).split(":")
clientService.getRegisteredAppByIdAndSecret(clientId, clientSecret) flatMap {
case Some(app) =>
logger.info(s"[clientAuthentication] - Matched client with clientId $clientId")
f(req)(app)
case None =>
logger.warn(s"[clientAuthentication] - No client has been found matching clientId $clientId")
Future.successful(Unauthorized(StandardErrors.INVALID_CLIENT))
}
Some((clientId, clientSecret))
}
)
} else {
logger.warn(s"[clientAuthentication] - Auth header wasn't of type Basic")
Future.successful(Unauthorized(StandardErrors.INVALID_REQUEST))
None
}
}
}

private def noAuthHeader = {
logger.warn(s"[clientAuthentication] - No auth header found in the request")
None
}
}
2 changes: 1 addition & 1 deletion app/controllers/api/RevokationController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ trait RevokationController extends BaseController with BasicAuthAction {

val tokenOrchestrator: TokenOrchestrator

def revokeToken(): Action[AnyContent] = clientAuthentication { implicit req => _ =>
def revokeToken(): Action[AnyContent] = clientAuthentication { implicit req => _ => _ =>
val body = req.body.asFormUrlEncoded.getOrElse(Map())

val token = body.get("token").map(_.head)
Expand Down
17 changes: 8 additions & 9 deletions app/controllers/ui/OAuthController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,26 @@

package controllers.ui

import controllers.actions.AuthenticatedAction
import controllers.actions.{AuthenticatedAction, BasicAuthAction}
import javax.inject.Inject
import orchestrators._
import org.slf4j.LoggerFactory
import play.api.libs.json.{Json, Writes}
import play.api.mvc.{Action, AnyContent, BaseController, ControllerComponents}
import services.ClientService
import views.html.auth.Grant

import scala.concurrent.{Future, ExecutionContext => ExC}

class DefaultOAuthController @Inject()(val controllerComponents: ControllerComponents,
val tokenOrchestrator: TokenOrchestrator,
val grantOrchestrator: GrantOrchestrator,
val clientService: ClientService,
val userOrchestrator: UserOrchestrator) extends OAuthController {
override implicit val ec: ExC = controllerComponents.executionContext
}

trait OAuthController extends BaseController with AuthenticatedAction {
trait OAuthController extends BaseController with AuthenticatedAction with BasicAuthAction {

protected val grantOrchestrator: GrantOrchestrator
protected val tokenOrchestrator: TokenOrchestrator
Expand All @@ -46,7 +48,6 @@ trait OAuthController extends BaseController with AuthenticatedAction {
val idToken = issued.idToken.fold(Json.obj())(id => Json.obj("id_token" -> id))
val refreshToken = issued.refreshToken.fold(Json.obj())(refresh => Json.obj("refresh_token" -> refresh))


Json.obj(
"token_type" -> issued.tokenType,
"scope" -> issued.scope,
Expand All @@ -55,30 +56,28 @@ trait OAuthController extends BaseController with AuthenticatedAction {
) ++ idToken ++ refreshToken
}

def getToken(): Action[AnyContent] = Action.async { implicit req =>
def getToken(): Action[AnyContent] = clientAuthenticationOptionalPkce { implicit req => app => codeVerifier =>
val params = req.body.asFormUrlEncoded.getOrElse(Map())
val grantType = params("grant_type").headOption.getOrElse("")

logger.info(s"[getToken] - Attempting to issue tokens using the $grantType grant")
grantType match {
case "authorization_code" =>
val authCode = params("code").headOption.getOrElse("")
val clientId = params("client_id").headOption.getOrElse("")
val redirectUri = params("redirect_uri").headOption.getOrElse("")
val codeVerifier = params.getOrElse("code_verifier", Seq()).headOption
tokenOrchestrator.authorizationCodeGrant(authCode, clientId, redirectUri, codeVerifier) map {
tokenOrchestrator.authorizationCodeGrant(authCode, app, redirectUri, codeVerifier) map {
case iss@Issued(_,_,_,_,_,_) => Ok(Json.toJson(iss))
case resp => BadRequest(Json.obj("error" -> resp.toString))
}
case "client_credentials" =>
val scope = params("scope").headOption.getOrElse("")
tokenOrchestrator.clientCredentialsGrant(scope) map {
tokenOrchestrator.clientCredentialsGrant(app, scope) map {
case iss@Issued(_,_,_,_,_,_) => Ok(Json.toJson(iss))
case resp => BadRequest(Json.obj("error" -> resp.toString))
}
case "refresh_token" =>
val refreshToken = params("refresh_token").headOption.getOrElse("")
tokenOrchestrator.refreshTokenGrant(refreshToken) map {
tokenOrchestrator.refreshTokenGrant(app, refreshToken) map {
case iss@Issued(_,_,_,_,_,_) => Ok(Json.toJson(iss))
case resp => BadRequest(Json.obj("error" -> resp.toString))
}
Expand Down
13 changes: 11 additions & 2 deletions app/models/RegisteredApplication.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.util.UUID

import com.cjwwdev.security.obfuscation.Obfuscators
import com.cjwwdev.security.Implicits._
import com.cjwwdev.security.deobfuscation.DeObfuscators
import org.bson.codecs.configuration.CodecProvider
import org.joda.time.DateTime
import org.mongodb.scala.bson.codecs.Macros
Expand All @@ -41,13 +42,21 @@ case class RegisteredApplication(appId: String,
idTokenExpiry: Long,
accessTokenExpiry: Long,
refreshTokenExpiry: Long,
createdAt: DateTime)
createdAt: DateTime) {
}

object RegisteredApplication extends Obfuscators with TimeFormat {
object RegisteredApplication extends Obfuscators with DeObfuscators with TimeFormat {
override val locale: String = this.getClass.getCanonicalName

val codec: CodecProvider = Macros.createCodecProviderIgnoreNone[RegisteredApplication]()

def decode(app: RegisteredApplication): RegisteredApplication = {
app.copy(
clientId = stringDeObfuscate.decrypt(app.clientId).getOrElse(app.clientId),
clientSecret = app.clientSecret.map(sec => stringDeObfuscate.decrypt(sec).getOrElse(""))
)
}

def generateIds(iterations: Int): String = {
(0 to iterations)
.map(_ => UUID.randomUUID().toString.replace("-", ""))
Expand Down
Loading

0 comments on commit 82cf2a4

Please sign in to comment.