diff --git a/server_test.go b/server_test.go index a51f8ea..6282319 100644 --- a/server_test.go +++ b/server_test.go @@ -219,15 +219,19 @@ func TestAuthorizationServer_retrieveClient(t *testing.T) { } func TestAuthorizationServer_writeJSON(t *testing.T) { - type fields struct{} + type fields struct { + allowedOrigin string + } type args struct { w http.ResponseWriter value interface{} } tests := []struct { - name string - args args - wantCode int + name string + fields fields + args args + wantCode int + wantHeader http.Header }{ { name: "stream error", @@ -238,12 +242,36 @@ func TestAuthorizationServer_writeJSON(t *testing.T) { }, }, wantCode: http.StatusInternalServerError, + wantHeader: http.Header{ + "Content-Type": []string{"text/plain; charset=utf-8"}, + "X-Content-Type-Options": []string{"nosniff"}, + }, + }, + { + name: "cors enabled", + fields: fields{ + allowedOrigin: "some-origin", + }, + args: args{ + w: &mock.MockResponseRecorder{ + ResponseRecorder: httptest.NewRecorder(), + WriteError: nil, + }, + value: string("test"), + }, + wantCode: http.StatusOK, + wantHeader: http.Header{ + "Content-Type": []string{"application/json"}, + "Access-Control-Allow-Origin": []string{"some-origin"}, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - srv := &AuthorizationServer{} + srv := &AuthorizationServer{ + allowedOrigin: tt.fields.allowedOrigin, + } srv.writeJSON(tt.args.w, tt.args.value) var rr *httptest.ResponseRecorder @@ -258,6 +286,11 @@ func TestAuthorizationServer_writeJSON(t *testing.T) { if gotCode != tt.wantCode { t.Errorf("AuthorizationServer.writeJSON() code = %v, wantCode %v", gotCode, tt.wantCode) } + + gotHeader := rr.Header() + if !reflect.DeepEqual(gotHeader, tt.wantHeader) { + t.Errorf("AuthorizationServer.writeJSON() header = %v, wantHeader %v", gotCode, tt.wantCode) + } }) } }