diff --git a/lib/httpserver/httpserver_test.go b/lib/httpserver/httpserver_test.go index c1adb7fcba..474ea44663 100644 --- a/lib/httpserver/httpserver_test.go +++ b/lib/httpserver/httpserver_test.go @@ -150,8 +150,8 @@ func TestHandlerWrapperOptionsRequest(t *testing.T) { handlerCalled = true return true } - - f := func(t *testing.T, name string, corsDisabled bool, expectAllowOrigin bool) { + headersToCheck := []string{"Access-Control-Allow-Origin", "Access-Control-Allow-Headers"} + f := func(t *testing.T, corsDisabled bool) { t.Helper() handlerCalled = false @@ -161,6 +161,10 @@ func TestHandlerWrapperOptionsRequest(t *testing.T) { *disableCORS = origDisableCORS }() + wantCORSHeaderValue := "*" + if corsDisabled { + wantCORSHeaderValue = "" + } req := httptest.NewRequest(http.MethodOptions, "/api/v1/query_range", nil) w := httptest.NewRecorder() @@ -170,31 +174,23 @@ func TestHandlerWrapperOptionsRequest(t *testing.T) { _ = res.Body.Close() if res.StatusCode != http.StatusNoContent { - t.Fatalf("%s: unexpected status code; got %d; want %d", name, res.StatusCode, http.StatusNoContent) + t.Fatalf("unexpected status code; (-%d;+%d)", http.StatusNoContent, res.StatusCode) } if handlerCalled { - t.Fatalf("%s: request handler must not be called for OPTIONS requests", name) + t.Fatalf("request handler must not be called for OPTIONS requests") } - if got := res.Header.Get("Access-Control-Allow-Methods"); got != "*" { - t.Fatalf("%s: unexpected Access-Control-Allow-Methods; got %q; want %q", name, got, "*") - } - wantHeaders := "*" - if got := res.Header.Get("Access-Control-Allow-Headers"); got != wantHeaders { - t.Fatalf("%s: unexpected Access-Control-Allow-Headers; got %q; want %q", name, got, wantHeaders) - } - if expectAllowOrigin { - if got := res.Header.Get("Access-Control-Allow-Origin"); got != "*" { - t.Fatalf("%s: unexpected Access-Control-Allow-Origin; got %q; want %q", name, got, "*") - } - } else { - if got := res.Header.Get("Access-Control-Allow-Origin"); got != "" { - t.Fatalf("%s: Access-Control-Allow-Origin must be empty when CORS is disabled; got %q", name, got) + for _, h := range headersToCheck { + got := res.Header.Get(h) + if wantCORSHeaderValue != got { + t.Fatalf("unexpected header: %s value: (-%s;+%s)", h, wantCORSHeaderValue, got) } } } - f(t, "cors enabled", false, true) - f(t, "cors disabled", true, false) + // CORS disabled + f(t, false) + // CORS enabled + f(t, true) } func TestHandlerWrapper(t *testing.T) {