diff --git a/apptest/tests/vmauth_routing_test.go b/apptest/tests/vmauth_routing_test.go index 2f9d604f3a..b530106e7e 100644 --- a/apptest/tests/vmauth_routing_test.go +++ b/apptest/tests/vmauth_routing_test.go @@ -1,8 +1,10 @@ package tests import ( + "context" "fmt" "io" + "net" "net/http" "net/http/httptest" "net/url" @@ -301,3 +303,82 @@ unauthorized_user: assertBackendsRequestsCount(1) } + +func TestSingleVMAuthUseProxyProtocol(t *testing.T) { + tc := apptest.NewTestCase(t) + defer tc.Stop() + + var requestsCount int + var actualForwardedForHeader string + backend := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + actualForwardedForHeader = r.Header.Get("X-Forwarded-For") + requestsCount++ + })) + defer backend.Close() + + authConfig := fmt.Sprintf(` +unauthorized_user: + url_prefix: %s + `, backend.URL) + + vmauth := tc.MustStartVmauth("vmauth", []string{ + "-httpListenAddr.useProxyProtocol=true", + }, authConfig) + + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/backend", vmauth.GetHTTPListenAddr()), nil) + if err != nil { + t.Fatalf("cannot build http.Request: %s", err) + } + + // make request using proxy protocol + c := &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + conn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + + // Write a proxy protocol header to the connection + if _, err := conn.Write([]byte{ + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, // signature + 0x21, // version 2 + 0x11, // family IPv4 + 0x00, 0x0C, // length: 12 bytes (IPv4 + ports) + 192, 168, 1, 100, // source IP + 10, 0, 0, 1, // destination IP + 0x1F, 0x90, // source port 8080 + 0x00, 0x50, // destination port 80 + }); err != nil { + t.Fatalf("cannot send proxy protocol header: %s", err) + } + + return conn, nil + }, + }, + } + + resp, err := c.Do(req) + if err != nil { + t.Fatalf("cannot make http.Get request for target=%q: %s", req.URL, err) + } + responseText, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("cannot read response body: %s", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected http response code: %d, want: %d, response text: %s", resp.StatusCode, http.StatusOK, responseText) + } + + // ensure that request was proxied + if requestsCount != 1 { + t.Fatalf("expected to have %d unauthorized proxied requests, got: %d", 1, requestsCount) + } + + // ensure that X-Forwarded-For header is set to the source IP from proxy protocol + expectedForwardedForHeader := "192.168.1.100" + if actualForwardedForHeader != expectedForwardedForHeader { + t.Fatalf("expected X-Forwarded-For header to be equal to proxy source IP, got: %s, want: %s'", actualForwardedForHeader, expectedForwardedForHeader) + } +}