mirror of
https://github.com/VictoriaMetrics/VictoriaMetrics.git
synced 2026-06-01 08:02:53 +03:00
Compare commits
4 Commits
debug-grou
...
vmauth-jwt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22a60f0b15 | ||
|
|
513f5a66d6 | ||
|
|
5b053c9bbb | ||
|
|
3e19926a7a |
@@ -875,12 +875,13 @@ func reloadAuthConfigData(data []byte) (bool, error) {
|
||||
return false, fmt.Errorf("failed to parse auth config: %w", err)
|
||||
}
|
||||
|
||||
jui, err := parseJWTUsers(ac)
|
||||
jui, ds, err := parseJWTUsers(ac)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to parse JWT users from auth config: %w", err)
|
||||
}
|
||||
jwtc := &jwtCache{
|
||||
users: jui,
|
||||
ds: ds,
|
||||
}
|
||||
|
||||
m, err := parseAuthConfigUsers(ac)
|
||||
@@ -899,6 +900,11 @@ func reloadAuthConfigData(data []byte) (bool, error) {
|
||||
}
|
||||
metrics.RegisterSet(ac.ms)
|
||||
|
||||
jwtcPrev := jwtAuthCache.Load()
|
||||
if jwtcPrev != nil {
|
||||
jwtcPrev.ds.stop()
|
||||
}
|
||||
|
||||
authConfig.Store(ac)
|
||||
authConfigData.Store(&data)
|
||||
authUsers.Store(&m)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/VictoriaMetrics/VictoriaMetrics/lib/jwt"
|
||||
@@ -44,18 +47,34 @@ var urlPathPlaceHolders = []string{
|
||||
type jwtCache struct {
|
||||
// users contain UserInfo`s from AuthConfig with JWTConfig set
|
||||
users []*UserInfo
|
||||
|
||||
ds *oidcDiscoverers
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
PublicKeys []string `yaml:"public_keys,omitempty"`
|
||||
PublicKeyFiles []string `yaml:"public_key_files,omitempty"`
|
||||
SkipVerify bool `yaml:"skip_verify,omitempty"`
|
||||
PublicKeys []string `yaml:"public_keys,omitempty"`
|
||||
PublicKeyFiles []string `yaml:"public_key_files,omitempty"`
|
||||
SkipVerify bool `yaml:"skip_verify,omitempty"`
|
||||
OIDC *OIDCConfig `yaml:"oidc,omitempty"`
|
||||
|
||||
verifierPool *jwt.VerifierPool
|
||||
// verifierPool is used to verify JWT tokens.
|
||||
// It is initialized from PublicKeys and/or PublicKeyFiles.
|
||||
// In this case, it is initialized once at config reload and never updated until next reload
|
||||
// In case of OIDC, it is initialized on config reload and periodically updated by discovery process.
|
||||
verifierPool atomic.Pointer[jwt.VerifierPool]
|
||||
}
|
||||
|
||||
func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
|
||||
func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, *oidcDiscoverers, error) {
|
||||
jui := make([]*UserInfo, 0, len(ac.Users))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ds := &oidcDiscoverers{
|
||||
ds: make(map[string]*oidcDiscoverer),
|
||||
context: ctx,
|
||||
cancel: cancel,
|
||||
wg: &sync.WaitGroup{},
|
||||
}
|
||||
|
||||
for _, ui := range ac.Users {
|
||||
jwtToken := ui.JWT
|
||||
if jwtToken == nil {
|
||||
@@ -63,10 +82,10 @@ func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
|
||||
}
|
||||
|
||||
if ui.AuthToken != "" || ui.BearerToken != "" || ui.Username != "" || ui.Password != "" {
|
||||
return nil, fmt.Errorf("auth_token, bearer_token, username and password cannot be specified if jwt is set")
|
||||
return nil, nil, fmt.Errorf("auth_token, bearer_token, username and password cannot be specified if jwt is set")
|
||||
}
|
||||
if len(jwtToken.PublicKeys) == 0 && len(jwtToken.PublicKeyFiles) == 0 && !jwtToken.SkipVerify {
|
||||
return nil, fmt.Errorf("jwt must contain at least a single public key, public_key_files or have skip_verify=true")
|
||||
if len(jwtToken.PublicKeys) == 0 && len(jwtToken.PublicKeyFiles) == 0 && !jwtToken.SkipVerify && jwtToken.OIDC == nil {
|
||||
return nil, nil, fmt.Errorf("jwt must contain at least a single public key, public_key_files, oidc or have skip_verify=true")
|
||||
}
|
||||
|
||||
if len(jwtToken.PublicKeys) > 0 || len(jwtToken.PublicKeyFiles) > 0 {
|
||||
@@ -75,7 +94,7 @@ func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
|
||||
for i := range jwtToken.PublicKeys {
|
||||
k, err := jwt.ParseKey([]byte(jwtToken.PublicKeys[i]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
@@ -83,33 +102,45 @@ func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
|
||||
for _, filePath := range jwtToken.PublicKeyFiles {
|
||||
keyData, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read public key from file %q: %w", filePath, err)
|
||||
return nil, nil, fmt.Errorf("cannot read public key from file %q: %w", filePath, err)
|
||||
}
|
||||
k, err := jwt.ParseKey(keyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse public key from file %q: %w", filePath, err)
|
||||
return nil, nil, fmt.Errorf("cannot parse public key from file %q: %w", filePath, err)
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
vp, err := jwt.NewVerifierPool(keys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
jwtToken.verifierPool = vp
|
||||
jwtToken.verifierPool.Store(vp)
|
||||
}
|
||||
if jwtToken.OIDC != nil {
|
||||
if len(jwtToken.PublicKeys) > 0 || len(jwtToken.PublicKeyFiles) > 0 || jwtToken.SkipVerify {
|
||||
return nil, nil, fmt.Errorf("jwt with oidc cannot contain public keys or have skip_verify=true")
|
||||
}
|
||||
|
||||
if jwtToken.OIDC.Issuer == "" {
|
||||
return nil, nil, fmt.Errorf("oidc issuer cannot be empty")
|
||||
}
|
||||
|
||||
ds.add(jwtToken.OIDC.Issuer, &jwtToken.verifierPool)
|
||||
}
|
||||
|
||||
if err := parseJWTPlaceholdersForUserInfo(&ui, true); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := ui.initURLs(); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
metricLabels, err := ui.getMetricLabels()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse metric_labels: %w", err)
|
||||
return nil, nil, fmt.Errorf("cannot parse metric_labels: %w", err)
|
||||
}
|
||||
ui.requests = ac.ms.GetOrCreateCounter(`vmauth_user_requests_total` + metricLabels)
|
||||
ui.requestErrors = ac.ms.GetOrCreateCounter(`vmauth_user_request_errors_total` + metricLabels)
|
||||
@@ -128,7 +159,7 @@ func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
|
||||
|
||||
rt, err := newRoundTripper(ui.TLSCAFile, ui.TLSCertFile, ui.TLSKeyFile, ui.TLSServerName, ui.TLSInsecureSkipVerify)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot initialize HTTP RoundTripper: %w", err)
|
||||
return nil, nil, fmt.Errorf("cannot initialize HTTP RoundTripper: %w", err)
|
||||
}
|
||||
ui.rt = rt
|
||||
|
||||
@@ -137,10 +168,12 @@ func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
|
||||
|
||||
// TODO: the limitation will be lifted once claim based matching will be implemented
|
||||
if len(jui) > 1 {
|
||||
return nil, fmt.Errorf("multiple users with JWT tokens are not supported; found %d users", len(jui))
|
||||
return nil, nil, fmt.Errorf("multiple users with JWT tokens are not supported; found %d users", len(jui))
|
||||
}
|
||||
|
||||
return jui, nil
|
||||
ds.start()
|
||||
|
||||
return jui, ds, nil
|
||||
}
|
||||
|
||||
func getUserInfoByJWTToken(ats []string) (*UserInfo, *jwt.Token) {
|
||||
@@ -177,7 +210,31 @@ func getUserInfoByJWTToken(ats []string) (*UserInfo, *jwt.Token) {
|
||||
return ui, tkn
|
||||
}
|
||||
|
||||
if err := ui.JWT.verifierPool.Verify(tkn); err != nil {
|
||||
if ui.JWT.OIDC != nil {
|
||||
// OIDC requires iss claim.
|
||||
// It must match the discovery issuer URL set in OIDC config.
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||
if tkn.Issuer() == "" {
|
||||
continue
|
||||
}
|
||||
if tkn.Issuer() != ui.JWT.OIDC.Issuer {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
vp := ui.JWT.verifierPool.Load()
|
||||
|
||||
// Could be nil in case OIDC discovery has not succeeded yet.
|
||||
if vp == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// In case of OIDC, current verifier implementation is suboptimal.
|
||||
// It tries all keys for the same alg until it finds the right one or fails all of them.
|
||||
// OIDC require using kid claim from the token to choose a proper JWK.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RotateEncKeys
|
||||
// https://openid.net/specs/draft-jones-json-web-key-03.html#anchor4
|
||||
if err := vp.Verify(tkn); err != nil {
|
||||
if *logInvalidAuthTokens {
|
||||
logger.Infof("cannot verify jwt token: %s", err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -36,13 +39,14 @@ XOtclIk1uhc03oL9nOQ=
|
||||
}
|
||||
return
|
||||
}
|
||||
users, err := parseJWTUsers(ac)
|
||||
users, ds, err := parseJWTUsers(ac)
|
||||
if err != nil {
|
||||
if expErr != err.Error() {
|
||||
t.Fatalf("unexpected error; got\n%q\nwant \n%q", err.Error(), expErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
ds.stop()
|
||||
t.Fatalf("expecting non-nil error; got %v", users)
|
||||
}
|
||||
|
||||
@@ -80,28 +84,28 @@ users:
|
||||
users:
|
||||
- jwt: {}
|
||||
url_prefix: http://foo.bar
|
||||
`, `jwt must contain at least a single public key, public_key_files or have skip_verify=true`)
|
||||
`, `jwt must contain at least a single public key, public_key_files, oidc or have skip_verify=true`)
|
||||
|
||||
// jwt public_keys or skip_verify must be set, part 2
|
||||
f(`
|
||||
users:
|
||||
- jwt: {public_keys: null}
|
||||
url_prefix: http://foo.bar
|
||||
`, `jwt must contain at least a single public key, public_key_files or have skip_verify=true`)
|
||||
`, `jwt must contain at least a single public key, public_key_files, oidc or have skip_verify=true`)
|
||||
|
||||
// jwt public_keys or skip_verify must be set, part 3
|
||||
f(`
|
||||
users:
|
||||
- jwt: {public_keys: []}
|
||||
url_prefix: http://foo.bar
|
||||
`, `jwt must contain at least a single public key, public_key_files or have skip_verify=true`)
|
||||
`, `jwt must contain at least a single public key, public_key_files, oidc or have skip_verify=true`)
|
||||
|
||||
// jwt public_keys, public_key_files or skip_verify must be set
|
||||
f(`
|
||||
users:
|
||||
- jwt: {public_key_files: []}
|
||||
url_prefix: http://foo.bar
|
||||
`, `jwt must contain at least a single public key, public_key_files or have skip_verify=true`)
|
||||
`, `jwt must contain at least a single public key, public_key_files, oidc or have skip_verify=true`)
|
||||
|
||||
// invalid public key, part 1
|
||||
f(`
|
||||
@@ -196,6 +200,51 @@ users:
|
||||
`,
|
||||
"request header: \"AccountID\" has unsupported placeholder: \"{{ .LogsAccountID }}\", supported values are: {{.MetricsTenant}}, {{.MetricsExtraLabels}}, {{.MetricsExtraFilters}}, {{.LogsAccountID}}, {{.LogsProjectID}}, {{.LogsExtraFilters}}, {{.LogsExtraStreamFilters}}",
|
||||
)
|
||||
|
||||
// oidc is not an object
|
||||
f(`
|
||||
users:
|
||||
- jwt:
|
||||
oidc: "not an object"
|
||||
url_prefix: http://foo.bar
|
||||
`,
|
||||
"cannot unmarshal AuthConfig data: yaml: unmarshal errors:\n line 4: cannot unmarshal !!str `not an ...` into main.OIDCConfig",
|
||||
)
|
||||
|
||||
// oidc issuer empty
|
||||
f(`
|
||||
users:
|
||||
- jwt:
|
||||
oidc: {}
|
||||
url_prefix: http://foo.bar
|
||||
`,
|
||||
"oidc issuer cannot be empty",
|
||||
)
|
||||
|
||||
// oidc and public_keys are not allowed
|
||||
f(fmt.Sprintf(`
|
||||
users:
|
||||
- jwt:
|
||||
public_keys:
|
||||
- %q
|
||||
oidc:
|
||||
issuer: https://example.com
|
||||
url_prefix: http://foo.bar
|
||||
`, validRSAPublicKey),
|
||||
"jwt with oidc cannot contain public keys or have skip_verify=true",
|
||||
)
|
||||
|
||||
// oidc and skip_verify are not allowed
|
||||
f(`
|
||||
users:
|
||||
- jwt:
|
||||
skip_verify: true
|
||||
oidc:
|
||||
issuer: https://example.com
|
||||
url_prefix: http://foo.bar
|
||||
`,
|
||||
"jwt with oidc cannot contain public keys or have skip_verify=true",
|
||||
)
|
||||
}
|
||||
|
||||
func TestJWTParseAuthConfigSuccess(t *testing.T) {
|
||||
@@ -225,10 +274,11 @@ XOtclIk1uhc03oL9nOQ=
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
jui, err := parseJWTUsers(ac)
|
||||
jui, ds, err := parseJWTUsers(ac)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
defer ds.stop()
|
||||
|
||||
for _, ui := range jui {
|
||||
if ui.JWT == nil {
|
||||
@@ -236,13 +286,13 @@ XOtclIk1uhc03oL9nOQ=
|
||||
}
|
||||
|
||||
if ui.JWT.SkipVerify {
|
||||
if ui.JWT.verifierPool != nil {
|
||||
if ui.JWT.verifierPool.Load() != nil {
|
||||
t.Fatalf("unexpected non-nil verifier pool for skip_verify=true")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if ui.JWT.verifierPool == nil {
|
||||
if ui.JWT.verifierPool.Load() == nil {
|
||||
t.Fatalf("unexpected nil verifier pool for non-empty public keys")
|
||||
}
|
||||
}
|
||||
@@ -333,4 +383,51 @@ users:
|
||||
- %q
|
||||
url_prefix: http://foo.bar
|
||||
`, validECDSAPublicKey, rsaKeyFile))
|
||||
|
||||
// oidc stub server
|
||||
var ipSrv *httptest.Server
|
||||
ipSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/.well-known/openid-configuration" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"issuer": ipSrv.URL,
|
||||
"jwks_uri": fmt.Sprintf("%s/jwks", ipSrv.URL),
|
||||
})
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/jwks" {
|
||||
// reps generated by https://jwkset.com/generate
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kty": "RSA",
|
||||
"kid": "f13eee91-f566-4829-80fa-fca847c21f0e",
|
||||
"d": "Ua1llEFz3LZ05CrK5a2JxKMUEWJGXhBPPF20hHQjzxd1w0IEJK_mhPZQG8dNtBROBNIi1FC9l6QRw-RTnVIVat5Xy4yDFNKXXL3ZLXejOHY8SXrNEIDqQ-cSwIpK9cK7Umib0PcPeEeeAED5mqDH75D8_YssWFF18kLbNB5Z9pZmn6Fshiht7l2Sh4GN-KcReOW6eiQQwckDte3OGmZCRbtEriLWJt5TUGUvfZVIlcclqNMycNB6jGa9E1pO5Up7Ki3ZbI_-6XmRgZPtqnR9oLJ1zn3fj3hYpCXo-zcqLuOu3qxcslsq5igsfBzgGtfIJHY9LfWmHUsaDEa5cAX1gQ",
|
||||
"n": "xbLXXBTNREk70UCMiqZ53_mTzYh89W-UaPU61GZ-RZ5lYcLgyWOb5mdyRbvJpcgfZpsOeGAUWbk3GkQ4vqn8kUMnnWhUum2Qk9kGubOJGLW6yaURd00j3E-ilQ5xO2R_Hzz8bAojxV8GKdGTQ-iTf8z8nsSHH8kR2SERbNJCFFtwtFU7vyFWyoH4Lmvu2UpICTHFCR9RqwQVjyoKB1JjJ6Dh1L4zPTlsvQEnqoeFQHPYr0QcQSMYXdfPvlt_FiLOAOE89fX_9T2r9WbFAoda3uTRE5_aal0jxUU2cFyeVSIgauNtF07fp422XFb4XPkWQWrdNx0KX53laSIYQ9HOpw",
|
||||
"e": "AQAB",
|
||||
"p": "2JT57AD-Q2lamgjgyn0wL7DgYZ3OoCTTrDm5_NHg6h13uDvyIlXSukuUeWm4tzPSDedpstbS7dgXkLw5eQXBHwPYtByTcEZS8Z37CBnhMOOhfo_U1aNIPPanJACvWBgz47-TxHsxW1YhztZqghRoicBZPSSBAj49MgANJ4jF0zc",
|
||||
"q": "6a4MkeSXJI-ZzQ-bgP8hwJqpLFr0AiNGQcjZMH4Nn4CPGdnGiqqe6flhfLimgbNhbb67B0-8fLIji8zGhGKDL_JSIpAAdmfs2vzeEsY2hScrqVbd1VbfRcRh0J6lsn7obxkbvQthp9sX2DQbeDcEeaFEvd9gDKQSATYEqWo7eBE",
|
||||
"dp": "haL2yu6Z9RJuuxi7S3YPY33qFZF_y0St71j3L854zzw7gMxMTW9TRWwZQwk-1pv9AmNFzvnK0MNDVyUs-UXZsb932TrApshdqYRnPsppLvdl0GgDVYcYrbUr0IUzrFHSwraVAOlavRbaaXvX4EejcUvkRFvf1nh83fs2Iqy8E-U",
|
||||
"dq": "Cnf5qC-Ndd3ZDg688LJ9WJuVKJ-Kfu4Fn7zXvgxnn9Wqk4XmFyA9rk21yFidXQIkQz5gMpun3g48-W5bFmMzbVp1w4af_q35NnZNnJm0p5Jxqkxx87TIm9-IYkg5NB3rW87MJ1PzNAnkr5LmCCSu1qQa6Eaxjt9qzxMUcmKH94E",
|
||||
"qi": "saAeU11iaKHmye3cwCAYkegcyWbXV3xIXEVJtS9Af_yM19UhspwY2VhuwRaajcwYZwtvR9_ITmX9M-ea7uLdd7aDYO1fujC8NGbopeC4Hkr7yb5vTly3pfKf4h-3LwGGUucJUetdz1lmMIYiyuG4_gSf1yIEtPDLKzXiedgEMdI"
|
||||
}
|
||||
]
|
||||
}
|
||||
`))
|
||||
return
|
||||
}
|
||||
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer ipSrv.Close()
|
||||
|
||||
f(`
|
||||
users:
|
||||
- jwt:
|
||||
oidc:
|
||||
issuer: ` + ipSrv.URL + `
|
||||
url_prefix: http://foo.bar
|
||||
`)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -1235,7 +1236,142 @@ users:
|
||||
request,
|
||||
responseExpected,
|
||||
)
|
||||
}
|
||||
|
||||
func TestOIDCRequestHandler(t *testing.T) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot generate RSA key: %s", err)
|
||||
}
|
||||
|
||||
var oidcSrv *httptest.Server
|
||||
oidcRespOK := atomic.Bool{}
|
||||
oidcRespOK.Store(true)
|
||||
|
||||
oidcSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"issuer": oidcSrv.URL,
|
||||
"jwks_uri": oidcSrv.URL + "/jwks",
|
||||
}); err != nil {
|
||||
panic(fmt.Errorf("cannot write openid-configuration response: %w", err))
|
||||
}
|
||||
case "/jwks":
|
||||
if !oidcRespOK.Load() {
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Encode the RSA public key in JWK format (base64url, no padding)
|
||||
nBytes := privateKey.PublicKey.N.Bytes()
|
||||
eBytes := big.NewInt(int64(privateKey.PublicKey.E)).Bytes()
|
||||
jwksBody := fmt.Sprintf(`{"keys":[{"kty":"RSA","kid":%q,"n":%q,"e":%q}]}`,
|
||||
`test-key-id`,
|
||||
base64.RawURLEncoding.EncodeToString(nBytes),
|
||||
base64.RawURLEncoding.EncodeToString(eBytes),
|
||||
)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write([]byte(jwksBody)); err != nil {
|
||||
panic(fmt.Errorf("cannot write jwks response: %w", err))
|
||||
}
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer oidcSrv.Close()
|
||||
|
||||
headerJSON, err := json.Marshal(map[string]any{
|
||||
"alg": "RS256",
|
||||
"typ": "JWT",
|
||||
"iss": oidcSrv.URL,
|
||||
"kid": `test-key-id`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("cannot marshal JWT header: %s", err)
|
||||
}
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
|
||||
bodyJSON, err := json.Marshal(map[string]any{
|
||||
"exp": 9999999999,
|
||||
"iss": oidcSrv.URL,
|
||||
"vm_access": map[string]any{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("cannot marshal JWT body: %s", err)
|
||||
}
|
||||
bodyB64 := base64.RawURLEncoding.EncodeToString(bodyJSON)
|
||||
|
||||
payload := headerB64 + "." + bodyB64
|
||||
|
||||
var signatureB64 string
|
||||
hash := crypto.SHA256
|
||||
h := hash.New()
|
||||
h.Write([]byte(payload))
|
||||
digest := h.Sum(nil)
|
||||
|
||||
signature, err := rsa.SignPKCS1v15(rand.Reader, privateKey, hash, digest)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot sign JWT token: %s", err)
|
||||
}
|
||||
signatureB64 = base64.RawURLEncoding.EncodeToString(signature)
|
||||
|
||||
tkn := payload + "." + signatureB64
|
||||
|
||||
backSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer backSrv.Close()
|
||||
|
||||
f := func(responseExpected string) {
|
||||
t.Helper()
|
||||
|
||||
cfgStr := `
|
||||
users:
|
||||
- jwt:
|
||||
oidc:
|
||||
issuer: ` + oidcSrv.URL + `
|
||||
url_prefix: ` + backSrv.URL + `/
|
||||
`
|
||||
|
||||
cfgOrigP := authConfigData.Load()
|
||||
if _, err := reloadAuthConfigData([]byte(cfgStr)); err != nil {
|
||||
t.Fatalf("cannot load config data: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
cfgOrig := []byte("unauthorized_user:\n url_prefix: http://foo/bar")
|
||||
if cfgOrigP != nil {
|
||||
cfgOrig = *cfgOrigP
|
||||
}
|
||||
if _, err := reloadAuthConfigData(cfgOrig); err != nil {
|
||||
t.Fatalf("cannot restore original config: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
r := httptest.NewRequest("GET", "http://some-host.com/api/v1/query", nil)
|
||||
r.Header.Set("Authorization", "Bearer "+tkn)
|
||||
|
||||
w := &fakeResponseWriter{}
|
||||
if !requestHandlerWithInternalRoutes(w, r) {
|
||||
t.Fatalf("unexpected false returned from requestHandler")
|
||||
}
|
||||
|
||||
if response := w.getResponse(); response != responseExpected {
|
||||
t.Fatalf("unexpected response\ngot\n%s\nwant\n%s", response, responseExpected)
|
||||
}
|
||||
}
|
||||
|
||||
// successful
|
||||
f(`statusCode=200
|
||||
`)
|
||||
|
||||
oidcRespOK.Store(false)
|
||||
// OIDC server error
|
||||
f(`statusCode=401
|
||||
Unauthorized
|
||||
`)
|
||||
}
|
||||
|
||||
type fakeResponseWriter struct {
|
||||
|
||||
259
app/vmauth/oidc.go
Normal file
259
app/vmauth/oidc.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/VictoriaMetrics/VictoriaMetrics/lib/jwt"
|
||||
"github.com/VictoriaMetrics/VictoriaMetrics/lib/logger"
|
||||
)
|
||||
|
||||
type OIDCConfig struct {
|
||||
Issuer string `yaml:"issuer"`
|
||||
}
|
||||
|
||||
type oidcDiscoverers struct {
|
||||
ds map[string]*oidcDiscoverer
|
||||
context context.Context
|
||||
cancel func()
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func (ds *oidcDiscoverers) stop() {
|
||||
ds.cancel()
|
||||
ds.wg.Wait()
|
||||
}
|
||||
|
||||
func (ds *oidcDiscoverers) add(iss string, vp *atomic.Pointer[jwt.VerifierPool]) {
|
||||
d, ok := ds.ds[iss]
|
||||
if !ok {
|
||||
d = &oidcDiscoverer{
|
||||
issuer: iss,
|
||||
context: ds.context,
|
||||
cancel: ds.cancel,
|
||||
wg: ds.wg,
|
||||
}
|
||||
ds.ds[iss] = d
|
||||
}
|
||||
|
||||
d.vps = append(d.vps, vp)
|
||||
}
|
||||
|
||||
func (ds *oidcDiscoverers) start() {
|
||||
for _, d := range ds.ds {
|
||||
d.start()
|
||||
}
|
||||
}
|
||||
|
||||
type oidcDiscoverer struct {
|
||||
vps []*atomic.Pointer[jwt.VerifierPool]
|
||||
issuer string
|
||||
|
||||
context context.Context
|
||||
cancel func()
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func (d *oidcDiscoverer) start() {
|
||||
if err := d.refreshVerifierPool(); err != nil {
|
||||
logger.Errorf("failed to refresh OIDC verifier pool for issuer %q: %v", d.issuer, err)
|
||||
}
|
||||
|
||||
d.wg.Go(func() {
|
||||
t := time.NewTimer(time.Second * 10)
|
||||
defer t.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
if err := d.refreshVerifierPool(); err != nil {
|
||||
t.Reset(time.Second * 10)
|
||||
logger.Errorf("failed to refresh OIDC verifier pool for issuer %q: %v", d.issuer, err)
|
||||
}
|
||||
// OIDC may reutrn Cache-Control header with max-age directive.
|
||||
// It could be used as time rage for next refresh.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RotateEncKeys
|
||||
t.Reset(time.Minute * 5)
|
||||
case <-d.context.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (d *oidcDiscoverer) refreshVerifierPool() error {
|
||||
cfg, err := getOpenIDConfiguration(d.issuer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// The issuer in the OIDC configuration must match the expected issuer.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RotateEncKeys
|
||||
if cfg.Issuer != d.issuer {
|
||||
return fmt.Errorf("openid configuration issuer %q does not match expected issuer %q", cfg.Issuer, d.issuer)
|
||||
}
|
||||
|
||||
keys, err := fetchJWKs(cfg.JWKsURI)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
verifierPool, err := jwt.NewVerifierPool(keys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, vp := range d.vps {
|
||||
vp.Store(verifierPool)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type jwksResponse struct {
|
||||
Keys []jwk `json:"keys"`
|
||||
}
|
||||
|
||||
// See https://www.rfc-editor.org/rfc/rfc7517 for details.
|
||||
type jwk struct {
|
||||
Type string `json:"kty"`
|
||||
Alg string `json:"alg"`
|
||||
Use string `json:"use"`
|
||||
Kid string `json:"kid"`
|
||||
|
||||
// RSA keys contents
|
||||
E string `json:"e"`
|
||||
N string `json:"n"`
|
||||
|
||||
// EC keys contents
|
||||
Crv string `json:"crv"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
// See https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata for details.
|
||||
type openidConfig struct {
|
||||
Issuer string `json:"issuer"`
|
||||
JWKsURI string `json:"jwks_uri"`
|
||||
}
|
||||
|
||||
func fetchJWKs(jwksURI string) ([]any, error) {
|
||||
resp, err := http.Get(jwksURI)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch jwks keys from %q: %v", jwksURI, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code %d when fetching jwks keys from %q", resp.StatusCode, jwksURI)
|
||||
}
|
||||
|
||||
var jwks jwksResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to decode jwks response from %q: %v", jwksURI, err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
keys, err := parseJwksKeys(&jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse jwks keys from %q: %v", jwksURI, err)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func getOpenIDConfiguration(issuer string) (openidConfig, error) {
|
||||
issuer, _ = strings.CutSuffix(issuer, "/")
|
||||
configURL := fmt.Sprintf("%s/.well-known/openid-configuration", issuer)
|
||||
|
||||
resp, err := http.Get(configURL)
|
||||
if err != nil {
|
||||
return openidConfig{}, fmt.Errorf("failed to fetch openid config from %q: %v", configURL, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return openidConfig{}, fmt.Errorf("unexpected status code %d when fetching openid config from %q", resp.StatusCode, configURL)
|
||||
}
|
||||
|
||||
var cfg openidConfig
|
||||
if err := json.NewDecoder(resp.Body).Decode(&cfg); err != nil {
|
||||
return openidConfig{}, fmt.Errorf("failed to decode openid config from %q: %v", configURL, err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func parseJwksKeys(resp *jwksResponse) ([]any, error) {
|
||||
keys := make(map[string]any)
|
||||
for _, key := range resp.Keys {
|
||||
if key.Kid == "" {
|
||||
return nil, fmt.Errorf("jwks key without kid found")
|
||||
}
|
||||
|
||||
switch key.Type {
|
||||
case "RSA":
|
||||
if key.E == "" || key.N == "" {
|
||||
return nil, fmt.Errorf("jwks key without e or n found")
|
||||
}
|
||||
e, err := base64.RawURLEncoding.DecodeString(key.E)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode jwks key e: %w", err)
|
||||
}
|
||||
n, err := base64.RawURLEncoding.DecodeString(key.N)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode jwks key n: %w", err)
|
||||
}
|
||||
keys[key.Kid] = &rsa.PublicKey{
|
||||
E: int(big.NewInt(0).SetBytes(e).Int64()),
|
||||
N: big.NewInt(0).SetBytes(n),
|
||||
}
|
||||
case "EC":
|
||||
if key.Crv == "" || key.X == "" || key.Y == "" {
|
||||
return nil, fmt.Errorf("jwks key without crv or x or y found")
|
||||
}
|
||||
x, err := base64.RawURLEncoding.DecodeString(key.X)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode jwks key x: %w", err)
|
||||
}
|
||||
y, err := base64.RawURLEncoding.DecodeString(key.Y)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode jwks key y: %w", err)
|
||||
}
|
||||
var curve elliptic.Curve
|
||||
switch key.Crv {
|
||||
case "P-256":
|
||||
curve = elliptic.P256()
|
||||
case "P-384":
|
||||
curve = elliptic.P384()
|
||||
case "P-521":
|
||||
curve = elliptic.P521()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported jwks key crv %q found", key.Crv)
|
||||
}
|
||||
keys[key.Kid] = &ecdsa.PublicKey{
|
||||
Curve: curve,
|
||||
X: big.NewInt(0).SetBytes(x),
|
||||
Y: big.NewInt(0).SetBytes(y),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
keysValues := make([]any, 0)
|
||||
for _, key := range keys {
|
||||
keysValues = append(keysValues, key)
|
||||
}
|
||||
|
||||
return keysValues, nil
|
||||
}
|
||||
@@ -99,6 +99,7 @@ type body struct {
|
||||
Exp int64 `json:"exp"`
|
||||
// issued at time unix_ts
|
||||
Iat int64 `json:"iat"`
|
||||
Iss string `json:"iss"`
|
||||
Jti string `json:"jti,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
vmAccessClaim VMAccessClaim
|
||||
@@ -138,6 +139,14 @@ func (b *body) parse(src string) error {
|
||||
return fmt.Errorf("cannot parse `iat` field: %w", err)
|
||||
}
|
||||
}
|
||||
if issObject := jv.Get("iss"); issObject != nil {
|
||||
bIss, err := issObject.StringBytes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot parse `iss` field: %w", err)
|
||||
}
|
||||
b.Iss = bytesutil.ToUnsafeString(bIss)
|
||||
}
|
||||
|
||||
vaObject := jv.Get("vm_access")
|
||||
if vaObject == nil {
|
||||
return ErrVMAccessFieldMissing
|
||||
@@ -265,6 +274,11 @@ func (t *Token) HasClaims(claims map[string]string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Issuer returns `iss` claim value from token body
|
||||
func (t *Token) Issuer() string {
|
||||
return t.body.Iss
|
||||
}
|
||||
|
||||
// VMAccess return a reference to the VMAccessClaim
|
||||
// all data are valid until Token is reachable
|
||||
func (t *Token) VMAccess() *VMAccessClaim {
|
||||
|
||||
@@ -305,6 +305,13 @@ func TestParseJWTBody_Failure(t *testing.T) {
|
||||
`unexpected type for key="logs_extra_stream_filters", got: string, want: array string`,
|
||||
true,
|
||||
)
|
||||
|
||||
// invalid iss claim value type
|
||||
f(
|
||||
`{"iss": {}, "vm_access": {}}`,
|
||||
"cannot parse `iss` field: value doesn't contain string; it contains object",
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
func TestParseJWTBody_Success(t *testing.T) {
|
||||
@@ -326,6 +333,9 @@ func TestParseJWTBody_Success(t *testing.T) {
|
||||
if result.Iat != resultExpected.Iat {
|
||||
t.Fatalf("unexpected Iat; got %d; want %d", result.Iat, resultExpected.Iat)
|
||||
}
|
||||
if result.Iss != resultExpected.Iss {
|
||||
t.Fatalf("unexpected Iss; got %s; want %s", result.Iss, resultExpected.Iss)
|
||||
}
|
||||
if result.Scope != resultExpected.Scope {
|
||||
t.Fatalf("unexpected scope; got %q; want %q", result.Scope, resultExpected.Scope)
|
||||
}
|
||||
@@ -349,6 +359,10 @@ func TestParseJWTBody_Success(t *testing.T) {
|
||||
f(`{"vm_access": {"tenant_id": {}}}`, &body{
|
||||
vmAccessClaim: VMAccessClaim{},
|
||||
})
|
||||
f(`{"iss": "theIssuer", "vm_access": {"tenant_id": {}}}`, &body{
|
||||
Iss: "theIssuer",
|
||||
vmAccessClaim: VMAccessClaim{},
|
||||
})
|
||||
|
||||
f(
|
||||
`
|
||||
|
||||
Reference in New Issue
Block a user