Skip to content

Commit

Permalink
conditionally formatting the policies endpoint url based on the curre…
Browse files Browse the repository at this point in the history
…nt (#469)

keycloak server version

Co-authored-by: Nerzal <tobias.theel@noobygames.de>
  • Loading branch information
osamaadam and Nerzal authored Nov 5, 2024
1 parent 75aae0c commit 51c8e6f
Showing 1 changed file with 84 additions and 3 deletions.
87 changes: 84 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/opentracing/opentracing-go"
"github.com/pkg/errors"
"github.com/segmentio/ksuid"
"golang.org/x/mod/semver"

"github.com/Nerzal/gocloak/v13/pkg/jwx"
)
Expand All @@ -36,6 +37,7 @@ type GoCloak struct {
logoutEndpoint string
openIDConnect string
attackDetection string
version string
}
}

Expand All @@ -51,6 +53,53 @@ func makeURL(path ...string) string {
return strings.Join(path, urlSeparator)
}

// Compares the provided version against the current version of the Keycloak server.
// Current version is fetched from the serverinfo if not already set.
//
// Returns:
//
// -1 if the provided version is lower than the server version
//
// 0 if the provided version is equal to the server version
//
// 1 if the provided version is higher than the server version
func (g *GoCloak) compareVersions(v, token string, ctx context.Context) (int, error) {
curVersion := g.Config.version
if curVersion == "" {
curV, err := g.getServerVersion(ctx, token);
if err != nil {
return 0, err
}

curVersion = curV
}

curVersion = "v" + g.Config.version
if (v[0] != 'v') {
v = "v" + v
}

return semver.Compare(curVersion, v), nil
}

// Get the server version from the serverinfo endpoint.
// If the version is already set, it will return the cached version.
// Otherwise, it will fetch the version from the serverinfo endpoint and cache it.
func (g *GoCloak) getServerVersion(ctx context.Context, token string) (string, error) {
if g.Config.version != "" {
return g.Config.version, nil
}

serverInfo, err := g.GetServerInfo(ctx, token)
if err != nil {
return "", err
}

g.Config.version = *(serverInfo.SystemInfo.Version)

return g.Config.version, nil
}

// GetRequest returns a request for calling endpoints.
func (g *GoCloak) GetRequest(ctx context.Context) *resty.Request {
var err HTTPErrorResponse
Expand Down Expand Up @@ -3539,8 +3588,14 @@ func (g *GoCloak) GetPolicies(ctx context.Context, token, realm, idOfClient stri
return nil, errors.Wrap(err, errMessage)
}

compResult, err := g.compareVersions("20.0.0", token, ctx)
if err != nil {
return nil, err
}
shouldAddType := compResult != 1

path := []string{"clients", idOfClient, "authz", "resource-server", "policy"}
if !NilOrEmpty(params.Type) {
if !NilOrEmpty(params.Type) && shouldAddType {
path = append(path, *params.Type)
}

Expand All @@ -3565,11 +3620,23 @@ func (g *GoCloak) CreatePolicy(ctx context.Context, token, realm, idOfClient str
return nil, errors.New("type of a policy required")
}

compResult, err := g.compareVersions("20.0.0", token, ctx)
if err != nil {
return nil, err
}
shouldAddType := compResult != 1

path := []string{"clients", idOfClient, "authz", "resource-server", "policy"}

if shouldAddType {
path = append(path, *policy.Type)
}

var result PolicyRepresentation
resp, err := g.GetRequestWithBearerAuth(ctx, token).
SetResult(&result).
SetBody(policy).
Post(g.getAdminRealmURL(realm, "clients", idOfClient, "authz", "resource-server", "policy", *(policy.Type)))
Post(g.getAdminRealmURL(realm, path...))

if err := checkForError(resp, err, errMessage); err != nil {
return nil, err
Expand All @@ -3586,9 +3653,23 @@ func (g *GoCloak) UpdatePolicy(ctx context.Context, token, realm, idOfClient str
return errors.New("ID of a policy required")
}

compResult, err := g.compareVersions("20.0.0", token, ctx)
if err != nil {
return err
}
shouldAddType := compResult != 1

path := []string{"clients", idOfClient, "authz", "resource-server", "policy"}

if shouldAddType {
path = append(path, *policy.Type)
}

path = append(path, *(policy.ID))

resp, err := g.GetRequestWithBearerAuth(ctx, token).
SetBody(policy).
Put(g.getAdminRealmURL(realm, "clients", idOfClient, "authz", "resource-server", "policy", *(policy.Type), *(policy.ID)))
Put(g.getAdminRealmURL(realm, path...))

return checkForError(resp, err, errMessage)
}
Expand Down

0 comments on commit 51c8e6f

Please sign in to comment.