Skip to content

Commit

Permalink
Allow multiauthenticator to try all authenticators (#1379)
Browse files Browse the repository at this point in the history
Previously we would fail if any authenticator returned an error. With
this change, each authenticator is attempted and if any return an error
the errors are aggregated and only returned if no token is successfully
set.
  • Loading branch information
wlynch authored Oct 30, 2024
1 parent 6a3e0cb commit e707968
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 2 deletions.
12 changes: 10 additions & 2 deletions pkg/apk/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"context"
"errors"
"net/http"
"os"
"os/exec"
Expand Down Expand Up @@ -40,16 +41,23 @@ func MultiAuthenticator(auths ...Authenticator) Authenticator { return multiAuth
type multiAuthenticator []Authenticator

func (m multiAuthenticator) AddAuth(ctx context.Context, req *http.Request) error {
var merr error
for _, a := range m {
if _, _, ok := req.BasicAuth(); ok {
// The request has auth, so we can stop here.
return nil
}
if err := a.AddAuth(ctx, req); err != nil {
return err
merr = errors.Join(merr, err)
continue
}
}
return nil

// One last check at the end to see if we added auth, else return the aggregated error.
if _, _, ok := req.BasicAuth(); ok {
return nil
}
return merr
}

// EnvAuth adds HTTP basic auth to the request if the request URL matches the
Expand Down
72 changes: 72 additions & 0 deletions pkg/apk/auth/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package auth

import (
"context"
"errors"
"net/http"
"testing"
)

type successAuth struct{}

func (s successAuth) AddAuth(_ context.Context, req *http.Request) error {
req.SetBasicAuth("user", "pass")
return nil
}

type failAuth struct{}

func (f failAuth) AddAuth(_ context.Context, req *http.Request) error {
return errors.New("failed to add auth")
}

func TestMultiAuthenticator(t *testing.T) {
tests := []struct {
name string
auths []Authenticator
expectAuth bool
expectErr bool
}{
{
name: "success auth first",
auths: []Authenticator{successAuth{}, failAuth{}},
expectAuth: true,
expectErr: false,
},
{
name: "fail auth first",
auths: []Authenticator{failAuth{}, successAuth{}},
expectAuth: true,
expectErr: false,
},
{
name: "all fail auth",
auths: []Authenticator{failAuth{}, failAuth{}},
expectAuth: false,
expectErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
multiAuth := MultiAuthenticator(tt.auths...)
req, _ := http.NewRequest("GET", "http://example.com", nil)
err := multiAuth.AddAuth(context.Background(), req)

if tt.expectErr && err == nil {
t.Errorf("expected error but got none")
}
if !tt.expectErr && err != nil {
t.Errorf("did not expect error but got: %v", err)
}

user, pass, ok := req.BasicAuth()
if tt.expectAuth && !ok {
t.Errorf("expected auth but got none")
}
if !tt.expectAuth && ok {
t.Errorf("did not expect auth but got user: %s, pass: %s", user, pass)
}
})
}
}

0 comments on commit e707968

Please sign in to comment.