Compare commits

...

4 Commits

Author SHA1 Message Date
Max Kotliar
22a60f0b15 test 2026-03-05 21:29:48 +02:00
Max Kotliar
513f5a66d6 app/vmauth: group discovery by issuer 2026-03-05 21:29:48 +02:00
Max Kotliar
5b053c9bbb Update app/vmauth/jwt.go
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Signed-off-by: Max Kotliar <kotlyar.maksim@gmail.com>
2026-03-05 19:50:05 +02:00
Max Kotliar
3e19926a7a app/vmauth: Implement OpenID Connect Discovery support
Fixes https://github.com/VictoriaMetrics/VictoriaMetrics/issues/10585

app/vmauth: add comment for verifierPool

app/vmauth: add comment
2026-03-05 19:46:49 +02:00
7 changed files with 612 additions and 29 deletions

View File

@@ -875,12 +875,13 @@ func reloadAuthConfigData(data []byte) (bool, error) {
return false, fmt.Errorf("failed to parse auth config: %w", err)
}
jui, err := parseJWTUsers(ac)
jui, ds, err := parseJWTUsers(ac)
if err != nil {
return false, fmt.Errorf("failed to parse JWT users from auth config: %w", err)
}
jwtc := &jwtCache{
users: jui,
ds: ds,
}
m, err := parseAuthConfigUsers(ac)
@@ -899,6 +900,11 @@ func reloadAuthConfigData(data []byte) (bool, error) {
}
metrics.RegisterSet(ac.ms)
jwtcPrev := jwtAuthCache.Load()
if jwtcPrev != nil {
jwtcPrev.ds.stop()
}
authConfig.Store(ac)
authConfigData.Store(&data)
authUsers.Store(&m)

View File

@@ -1,11 +1,14 @@
package main
import (
"context"
"fmt"
"net/url"
"os"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/jwt"
@@ -44,18 +47,34 @@ var urlPathPlaceHolders = []string{
type jwtCache struct {
// users contain UserInfo`s from AuthConfig with JWTConfig set
users []*UserInfo
ds *oidcDiscoverers
}
type JWTConfig struct {
PublicKeys []string `yaml:"public_keys,omitempty"`
PublicKeyFiles []string `yaml:"public_key_files,omitempty"`
SkipVerify bool `yaml:"skip_verify,omitempty"`
PublicKeys []string `yaml:"public_keys,omitempty"`
PublicKeyFiles []string `yaml:"public_key_files,omitempty"`
SkipVerify bool `yaml:"skip_verify,omitempty"`
OIDC *OIDCConfig `yaml:"oidc,omitempty"`
verifierPool *jwt.VerifierPool
// verifierPool is used to verify JWT tokens.
// It is initialized from PublicKeys and/or PublicKeyFiles.
// In this case, it is initialized once at config reload and never updated until next reload
// In case of OIDC, it is initialized on config reload and periodically updated by discovery process.
verifierPool atomic.Pointer[jwt.VerifierPool]
}
func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, *oidcDiscoverers, error) {
jui := make([]*UserInfo, 0, len(ac.Users))
ctx, cancel := context.WithCancel(context.Background())
ds := &oidcDiscoverers{
ds: make(map[string]*oidcDiscoverer),
context: ctx,
cancel: cancel,
wg: &sync.WaitGroup{},
}
for _, ui := range ac.Users {
jwtToken := ui.JWT
if jwtToken == nil {
@@ -63,10 +82,10 @@ func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
}
if ui.AuthToken != "" || ui.BearerToken != "" || ui.Username != "" || ui.Password != "" {
return nil, fmt.Errorf("auth_token, bearer_token, username and password cannot be specified if jwt is set")
return nil, nil, fmt.Errorf("auth_token, bearer_token, username and password cannot be specified if jwt is set")
}
if len(jwtToken.PublicKeys) == 0 && len(jwtToken.PublicKeyFiles) == 0 && !jwtToken.SkipVerify {
return nil, fmt.Errorf("jwt must contain at least a single public key, public_key_files or have skip_verify=true")
if len(jwtToken.PublicKeys) == 0 && len(jwtToken.PublicKeyFiles) == 0 && !jwtToken.SkipVerify && jwtToken.OIDC == nil {
return nil, nil, fmt.Errorf("jwt must contain at least a single public key, public_key_files, oidc or have skip_verify=true")
}
if len(jwtToken.PublicKeys) > 0 || len(jwtToken.PublicKeyFiles) > 0 {
@@ -75,7 +94,7 @@ func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
for i := range jwtToken.PublicKeys {
k, err := jwt.ParseKey([]byte(jwtToken.PublicKeys[i]))
if err != nil {
return nil, err
return nil, nil, err
}
keys = append(keys, k)
}
@@ -83,33 +102,45 @@ func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
for _, filePath := range jwtToken.PublicKeyFiles {
keyData, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("cannot read public key from file %q: %w", filePath, err)
return nil, nil, fmt.Errorf("cannot read public key from file %q: %w", filePath, err)
}
k, err := jwt.ParseKey(keyData)
if err != nil {
return nil, fmt.Errorf("cannot parse public key from file %q: %w", filePath, err)
return nil, nil, fmt.Errorf("cannot parse public key from file %q: %w", filePath, err)
}
keys = append(keys, k)
}
vp, err := jwt.NewVerifierPool(keys)
if err != nil {
return nil, err
return nil, nil, err
}
jwtToken.verifierPool = vp
jwtToken.verifierPool.Store(vp)
}
if jwtToken.OIDC != nil {
if len(jwtToken.PublicKeys) > 0 || len(jwtToken.PublicKeyFiles) > 0 || jwtToken.SkipVerify {
return nil, nil, fmt.Errorf("jwt with oidc cannot contain public keys or have skip_verify=true")
}
if jwtToken.OIDC.Issuer == "" {
return nil, nil, fmt.Errorf("oidc issuer cannot be empty")
}
ds.add(jwtToken.OIDC.Issuer, &jwtToken.verifierPool)
}
if err := parseJWTPlaceholdersForUserInfo(&ui, true); err != nil {
return nil, err
return nil, nil, err
}
if err := ui.initURLs(); err != nil {
return nil, err
return nil, nil, err
}
metricLabels, err := ui.getMetricLabels()
if err != nil {
return nil, fmt.Errorf("cannot parse metric_labels: %w", err)
return nil, nil, fmt.Errorf("cannot parse metric_labels: %w", err)
}
ui.requests = ac.ms.GetOrCreateCounter(`vmauth_user_requests_total` + metricLabels)
ui.requestErrors = ac.ms.GetOrCreateCounter(`vmauth_user_request_errors_total` + metricLabels)
@@ -128,7 +159,7 @@ func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
rt, err := newRoundTripper(ui.TLSCAFile, ui.TLSCertFile, ui.TLSKeyFile, ui.TLSServerName, ui.TLSInsecureSkipVerify)
if err != nil {
return nil, fmt.Errorf("cannot initialize HTTP RoundTripper: %w", err)
return nil, nil, fmt.Errorf("cannot initialize HTTP RoundTripper: %w", err)
}
ui.rt = rt
@@ -137,10 +168,12 @@ func parseJWTUsers(ac *AuthConfig) ([]*UserInfo, error) {
// TODO: the limitation will be lifted once claim based matching will be implemented
if len(jui) > 1 {
return nil, fmt.Errorf("multiple users with JWT tokens are not supported; found %d users", len(jui))
return nil, nil, fmt.Errorf("multiple users with JWT tokens are not supported; found %d users", len(jui))
}
return jui, nil
ds.start()
return jui, ds, nil
}
func getUserInfoByJWTToken(ats []string) (*UserInfo, *jwt.Token) {
@@ -177,7 +210,31 @@ func getUserInfoByJWTToken(ats []string) (*UserInfo, *jwt.Token) {
return ui, tkn
}
if err := ui.JWT.verifierPool.Verify(tkn); err != nil {
if ui.JWT.OIDC != nil {
// OIDC requires iss claim.
// It must match the discovery issuer URL set in OIDC config.
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
if tkn.Issuer() == "" {
continue
}
if tkn.Issuer() != ui.JWT.OIDC.Issuer {
continue
}
}
vp := ui.JWT.verifierPool.Load()
// Could be nil in case OIDC discovery has not succeeded yet.
if vp == nil {
continue
}
// In case of OIDC, current verifier implementation is suboptimal.
// It tries all keys for the same alg until it finds the right one or fails all of them.
// OIDC require using kid claim from the token to choose a proper JWK.
// https://openid.net/specs/openid-connect-core-1_0.html#RotateEncKeys
// https://openid.net/specs/draft-jones-json-web-key-03.html#anchor4
if err := vp.Verify(tkn); err != nil {
if *logInvalidAuthTokens {
logger.Infof("cannot verify jwt token: %s", err)
}

View File

@@ -1,7 +1,10 @@
package main
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
@@ -36,13 +39,14 @@ XOtclIk1uhc03oL9nOQ=
}
return
}
users, err := parseJWTUsers(ac)
users, ds, err := parseJWTUsers(ac)
if err != nil {
if expErr != err.Error() {
t.Fatalf("unexpected error; got\n%q\nwant \n%q", err.Error(), expErr)
}
return
}
ds.stop()
t.Fatalf("expecting non-nil error; got %v", users)
}
@@ -80,28 +84,28 @@ users:
users:
- jwt: {}
url_prefix: http://foo.bar
`, `jwt must contain at least a single public key, public_key_files or have skip_verify=true`)
`, `jwt must contain at least a single public key, public_key_files, oidc or have skip_verify=true`)
// jwt public_keys or skip_verify must be set, part 2
f(`
users:
- jwt: {public_keys: null}
url_prefix: http://foo.bar
`, `jwt must contain at least a single public key, public_key_files or have skip_verify=true`)
`, `jwt must contain at least a single public key, public_key_files, oidc or have skip_verify=true`)
// jwt public_keys or skip_verify must be set, part 3
f(`
users:
- jwt: {public_keys: []}
url_prefix: http://foo.bar
`, `jwt must contain at least a single public key, public_key_files or have skip_verify=true`)
`, `jwt must contain at least a single public key, public_key_files, oidc or have skip_verify=true`)
// jwt public_keys, public_key_files or skip_verify must be set
f(`
users:
- jwt: {public_key_files: []}
url_prefix: http://foo.bar
`, `jwt must contain at least a single public key, public_key_files or have skip_verify=true`)
`, `jwt must contain at least a single public key, public_key_files, oidc or have skip_verify=true`)
// invalid public key, part 1
f(`
@@ -196,6 +200,51 @@ users:
`,
"request header: \"AccountID\" has unsupported placeholder: \"{{ .LogsAccountID }}\", supported values are: {{.MetricsTenant}}, {{.MetricsExtraLabels}}, {{.MetricsExtraFilters}}, {{.LogsAccountID}}, {{.LogsProjectID}}, {{.LogsExtraFilters}}, {{.LogsExtraStreamFilters}}",
)
// oidc is not an object
f(`
users:
- jwt:
oidc: "not an object"
url_prefix: http://foo.bar
`,
"cannot unmarshal AuthConfig data: yaml: unmarshal errors:\n line 4: cannot unmarshal !!str `not an ...` into main.OIDCConfig",
)
// oidc issuer empty
f(`
users:
- jwt:
oidc: {}
url_prefix: http://foo.bar
`,
"oidc issuer cannot be empty",
)
// oidc and public_keys are not allowed
f(fmt.Sprintf(`
users:
- jwt:
public_keys:
- %q
oidc:
issuer: https://example.com
url_prefix: http://foo.bar
`, validRSAPublicKey),
"jwt with oidc cannot contain public keys or have skip_verify=true",
)
// oidc and skip_verify are not allowed
f(`
users:
- jwt:
skip_verify: true
oidc:
issuer: https://example.com
url_prefix: http://foo.bar
`,
"jwt with oidc cannot contain public keys or have skip_verify=true",
)
}
func TestJWTParseAuthConfigSuccess(t *testing.T) {
@@ -225,10 +274,11 @@ XOtclIk1uhc03oL9nOQ=
t.Fatalf("unexpected error: %s", err)
}
jui, err := parseJWTUsers(ac)
jui, ds, err := parseJWTUsers(ac)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer ds.stop()
for _, ui := range jui {
if ui.JWT == nil {
@@ -236,13 +286,13 @@ XOtclIk1uhc03oL9nOQ=
}
if ui.JWT.SkipVerify {
if ui.JWT.verifierPool != nil {
if ui.JWT.verifierPool.Load() != nil {
t.Fatalf("unexpected non-nil verifier pool for skip_verify=true")
}
continue
}
if ui.JWT.verifierPool == nil {
if ui.JWT.verifierPool.Load() == nil {
t.Fatalf("unexpected nil verifier pool for non-empty public keys")
}
}
@@ -333,4 +383,51 @@ users:
- %q
url_prefix: http://foo.bar
`, validECDSAPublicKey, rsaKeyFile))
// oidc stub server
var ipSrv *httptest.Server
ipSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/openid-configuration" {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{
"issuer": ipSrv.URL,
"jwks_uri": fmt.Sprintf("%s/jwks", ipSrv.URL),
})
return
}
if r.URL.Path == "/jwks" {
// reps generated by https://jwkset.com/generate
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`
{
"keys": [
{
"kty": "RSA",
"kid": "f13eee91-f566-4829-80fa-fca847c21f0e",
"d": "Ua1llEFz3LZ05CrK5a2JxKMUEWJGXhBPPF20hHQjzxd1w0IEJK_mhPZQG8dNtBROBNIi1FC9l6QRw-RTnVIVat5Xy4yDFNKXXL3ZLXejOHY8SXrNEIDqQ-cSwIpK9cK7Umib0PcPeEeeAED5mqDH75D8_YssWFF18kLbNB5Z9pZmn6Fshiht7l2Sh4GN-KcReOW6eiQQwckDte3OGmZCRbtEriLWJt5TUGUvfZVIlcclqNMycNB6jGa9E1pO5Up7Ki3ZbI_-6XmRgZPtqnR9oLJ1zn3fj3hYpCXo-zcqLuOu3qxcslsq5igsfBzgGtfIJHY9LfWmHUsaDEa5cAX1gQ",
"n": "xbLXXBTNREk70UCMiqZ53_mTzYh89W-UaPU61GZ-RZ5lYcLgyWOb5mdyRbvJpcgfZpsOeGAUWbk3GkQ4vqn8kUMnnWhUum2Qk9kGubOJGLW6yaURd00j3E-ilQ5xO2R_Hzz8bAojxV8GKdGTQ-iTf8z8nsSHH8kR2SERbNJCFFtwtFU7vyFWyoH4Lmvu2UpICTHFCR9RqwQVjyoKB1JjJ6Dh1L4zPTlsvQEnqoeFQHPYr0QcQSMYXdfPvlt_FiLOAOE89fX_9T2r9WbFAoda3uTRE5_aal0jxUU2cFyeVSIgauNtF07fp422XFb4XPkWQWrdNx0KX53laSIYQ9HOpw",
"e": "AQAB",
"p": "2JT57AD-Q2lamgjgyn0wL7DgYZ3OoCTTrDm5_NHg6h13uDvyIlXSukuUeWm4tzPSDedpstbS7dgXkLw5eQXBHwPYtByTcEZS8Z37CBnhMOOhfo_U1aNIPPanJACvWBgz47-TxHsxW1YhztZqghRoicBZPSSBAj49MgANJ4jF0zc",
"q": "6a4MkeSXJI-ZzQ-bgP8hwJqpLFr0AiNGQcjZMH4Nn4CPGdnGiqqe6flhfLimgbNhbb67B0-8fLIji8zGhGKDL_JSIpAAdmfs2vzeEsY2hScrqVbd1VbfRcRh0J6lsn7obxkbvQthp9sX2DQbeDcEeaFEvd9gDKQSATYEqWo7eBE",
"dp": "haL2yu6Z9RJuuxi7S3YPY33qFZF_y0St71j3L854zzw7gMxMTW9TRWwZQwk-1pv9AmNFzvnK0MNDVyUs-UXZsb932TrApshdqYRnPsppLvdl0GgDVYcYrbUr0IUzrFHSwraVAOlavRbaaXvX4EejcUvkRFvf1nh83fs2Iqy8E-U",
"dq": "Cnf5qC-Ndd3ZDg688LJ9WJuVKJ-Kfu4Fn7zXvgxnn9Wqk4XmFyA9rk21yFidXQIkQz5gMpun3g48-W5bFmMzbVp1w4af_q35NnZNnJm0p5Jxqkxx87TIm9-IYkg5NB3rW87MJ1PzNAnkr5LmCCSu1qQa6Eaxjt9qzxMUcmKH94E",
"qi": "saAeU11iaKHmye3cwCAYkegcyWbXV3xIXEVJtS9Af_yM19UhspwY2VhuwRaajcwYZwtvR9_ITmX9M-ea7uLdd7aDYO1fujC8NGbopeC4Hkr7yb5vTly3pfKf4h-3LwGGUucJUetdz1lmMIYiyuG4_gSf1yIEtPDLKzXiedgEMdI"
}
]
}
`))
return
}
http.NotFound(w, r)
}))
defer ipSrv.Close()
f(`
users:
- jwt:
oidc:
issuer: ` + ipSrv.URL + `
url_prefix: http://foo.bar
`)
}

View File

@@ -12,6 +12,7 @@ import (
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
@@ -1235,7 +1236,142 @@ users:
request,
responseExpected,
)
}
func TestOIDCRequestHandler(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("cannot generate RSA key: %s", err)
}
var oidcSrv *httptest.Server
oidcRespOK := atomic.Bool{}
oidcRespOK.Store(true)
oidcSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(map[string]string{
"issuer": oidcSrv.URL,
"jwks_uri": oidcSrv.URL + "/jwks",
}); err != nil {
panic(fmt.Errorf("cannot write openid-configuration response: %w", err))
}
case "/jwks":
if !oidcRespOK.Load() {
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
// Encode the RSA public key in JWK format (base64url, no padding)
nBytes := privateKey.PublicKey.N.Bytes()
eBytes := big.NewInt(int64(privateKey.PublicKey.E)).Bytes()
jwksBody := fmt.Sprintf(`{"keys":[{"kty":"RSA","kid":%q,"n":%q,"e":%q}]}`,
`test-key-id`,
base64.RawURLEncoding.EncodeToString(nBytes),
base64.RawURLEncoding.EncodeToString(eBytes),
)
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write([]byte(jwksBody)); err != nil {
panic(fmt.Errorf("cannot write jwks response: %w", err))
}
default:
http.NotFound(w, r)
}
}))
defer oidcSrv.Close()
headerJSON, err := json.Marshal(map[string]any{
"alg": "RS256",
"typ": "JWT",
"iss": oidcSrv.URL,
"kid": `test-key-id`,
})
if err != nil {
t.Fatalf("cannot marshal JWT header: %s", err)
}
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
bodyJSON, err := json.Marshal(map[string]any{
"exp": 9999999999,
"iss": oidcSrv.URL,
"vm_access": map[string]any{},
})
if err != nil {
t.Fatalf("cannot marshal JWT body: %s", err)
}
bodyB64 := base64.RawURLEncoding.EncodeToString(bodyJSON)
payload := headerB64 + "." + bodyB64
var signatureB64 string
hash := crypto.SHA256
h := hash.New()
h.Write([]byte(payload))
digest := h.Sum(nil)
signature, err := rsa.SignPKCS1v15(rand.Reader, privateKey, hash, digest)
if err != nil {
t.Fatalf("cannot sign JWT token: %s", err)
}
signatureB64 = base64.RawURLEncoding.EncodeToString(signature)
tkn := payload + "." + signatureB64
backSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer backSrv.Close()
f := func(responseExpected string) {
t.Helper()
cfgStr := `
users:
- jwt:
oidc:
issuer: ` + oidcSrv.URL + `
url_prefix: ` + backSrv.URL + `/
`
cfgOrigP := authConfigData.Load()
if _, err := reloadAuthConfigData([]byte(cfgStr)); err != nil {
t.Fatalf("cannot load config data: %s", err)
}
defer func() {
cfgOrig := []byte("unauthorized_user:\n url_prefix: http://foo/bar")
if cfgOrigP != nil {
cfgOrig = *cfgOrigP
}
if _, err := reloadAuthConfigData(cfgOrig); err != nil {
t.Fatalf("cannot restore original config: %s", err)
}
}()
r := httptest.NewRequest("GET", "http://some-host.com/api/v1/query", nil)
r.Header.Set("Authorization", "Bearer "+tkn)
w := &fakeResponseWriter{}
if !requestHandlerWithInternalRoutes(w, r) {
t.Fatalf("unexpected false returned from requestHandler")
}
if response := w.getResponse(); response != responseExpected {
t.Fatalf("unexpected response\ngot\n%s\nwant\n%s", response, responseExpected)
}
}
// successful
f(`statusCode=200
`)
oidcRespOK.Store(false)
// OIDC server error
f(`statusCode=401
Unauthorized
`)
}
type fakeResponseWriter struct {

259
app/vmauth/oidc.go Normal file
View File

@@ -0,0 +1,259 @@
package main
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/jwt"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/logger"
)
type OIDCConfig struct {
Issuer string `yaml:"issuer"`
}
type oidcDiscoverers struct {
ds map[string]*oidcDiscoverer
context context.Context
cancel func()
wg *sync.WaitGroup
}
func (ds *oidcDiscoverers) stop() {
ds.cancel()
ds.wg.Wait()
}
func (ds *oidcDiscoverers) add(iss string, vp *atomic.Pointer[jwt.VerifierPool]) {
d, ok := ds.ds[iss]
if !ok {
d = &oidcDiscoverer{
issuer: iss,
context: ds.context,
cancel: ds.cancel,
wg: ds.wg,
}
ds.ds[iss] = d
}
d.vps = append(d.vps, vp)
}
func (ds *oidcDiscoverers) start() {
for _, d := range ds.ds {
d.start()
}
}
type oidcDiscoverer struct {
vps []*atomic.Pointer[jwt.VerifierPool]
issuer string
context context.Context
cancel func()
wg *sync.WaitGroup
}
func (d *oidcDiscoverer) start() {
if err := d.refreshVerifierPool(); err != nil {
logger.Errorf("failed to refresh OIDC verifier pool for issuer %q: %v", d.issuer, err)
}
d.wg.Go(func() {
t := time.NewTimer(time.Second * 10)
defer t.Stop()
for {
select {
case <-t.C:
if err := d.refreshVerifierPool(); err != nil {
t.Reset(time.Second * 10)
logger.Errorf("failed to refresh OIDC verifier pool for issuer %q: %v", d.issuer, err)
}
// OIDC may reutrn Cache-Control header with max-age directive.
// It could be used as time rage for next refresh.
// https://openid.net/specs/openid-connect-core-1_0.html#RotateEncKeys
t.Reset(time.Minute * 5)
case <-d.context.Done():
return
}
}
})
}
func (d *oidcDiscoverer) refreshVerifierPool() error {
cfg, err := getOpenIDConfiguration(d.issuer)
if err != nil {
return err
}
// The issuer in the OIDC configuration must match the expected issuer.
// https://openid.net/specs/openid-connect-core-1_0.html#RotateEncKeys
if cfg.Issuer != d.issuer {
return fmt.Errorf("openid configuration issuer %q does not match expected issuer %q", cfg.Issuer, d.issuer)
}
keys, err := fetchJWKs(cfg.JWKsURI)
if err != nil {
return err
}
verifierPool, err := jwt.NewVerifierPool(keys)
if err != nil {
return err
}
for _, vp := range d.vps {
vp.Store(verifierPool)
}
return nil
}
type jwksResponse struct {
Keys []jwk `json:"keys"`
}
// See https://www.rfc-editor.org/rfc/rfc7517 for details.
type jwk struct {
Type string `json:"kty"`
Alg string `json:"alg"`
Use string `json:"use"`
Kid string `json:"kid"`
// RSA keys contents
E string `json:"e"`
N string `json:"n"`
// EC keys contents
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
}
// See https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata for details.
type openidConfig struct {
Issuer string `json:"issuer"`
JWKsURI string `json:"jwks_uri"`
}
func fetchJWKs(jwksURI string) ([]any, error) {
resp, err := http.Get(jwksURI)
if err != nil {
return nil, fmt.Errorf("failed to fetch jwks keys from %q: %v", jwksURI, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code %d when fetching jwks keys from %q", resp.StatusCode, jwksURI)
}
var jwks jwksResponse
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return nil, fmt.Errorf("failed to decode jwks response from %q: %v", jwksURI, err)
}
_ = resp.Body.Close()
keys, err := parseJwksKeys(&jwks)
if err != nil {
return nil, fmt.Errorf("failed to parse jwks keys from %q: %v", jwksURI, err)
}
return keys, nil
}
func getOpenIDConfiguration(issuer string) (openidConfig, error) {
issuer, _ = strings.CutSuffix(issuer, "/")
configURL := fmt.Sprintf("%s/.well-known/openid-configuration", issuer)
resp, err := http.Get(configURL)
if err != nil {
return openidConfig{}, fmt.Errorf("failed to fetch openid config from %q: %v", configURL, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return openidConfig{}, fmt.Errorf("unexpected status code %d when fetching openid config from %q", resp.StatusCode, configURL)
}
var cfg openidConfig
if err := json.NewDecoder(resp.Body).Decode(&cfg); err != nil {
return openidConfig{}, fmt.Errorf("failed to decode openid config from %q: %v", configURL, err)
}
_ = resp.Body.Close()
return cfg, nil
}
func parseJwksKeys(resp *jwksResponse) ([]any, error) {
keys := make(map[string]any)
for _, key := range resp.Keys {
if key.Kid == "" {
return nil, fmt.Errorf("jwks key without kid found")
}
switch key.Type {
case "RSA":
if key.E == "" || key.N == "" {
return nil, fmt.Errorf("jwks key without e or n found")
}
e, err := base64.RawURLEncoding.DecodeString(key.E)
if err != nil {
return nil, fmt.Errorf("failed to decode jwks key e: %w", err)
}
n, err := base64.RawURLEncoding.DecodeString(key.N)
if err != nil {
return nil, fmt.Errorf("failed to decode jwks key n: %w", err)
}
keys[key.Kid] = &rsa.PublicKey{
E: int(big.NewInt(0).SetBytes(e).Int64()),
N: big.NewInt(0).SetBytes(n),
}
case "EC":
if key.Crv == "" || key.X == "" || key.Y == "" {
return nil, fmt.Errorf("jwks key without crv or x or y found")
}
x, err := base64.RawURLEncoding.DecodeString(key.X)
if err != nil {
return nil, fmt.Errorf("failed to decode jwks key x: %w", err)
}
y, err := base64.RawURLEncoding.DecodeString(key.Y)
if err != nil {
return nil, fmt.Errorf("failed to decode jwks key y: %w", err)
}
var curve elliptic.Curve
switch key.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported jwks key crv %q found", key.Crv)
}
keys[key.Kid] = &ecdsa.PublicKey{
Curve: curve,
X: big.NewInt(0).SetBytes(x),
Y: big.NewInt(0).SetBytes(y),
}
}
}
keysValues := make([]any, 0)
for _, key := range keys {
keysValues = append(keysValues, key)
}
return keysValues, nil
}

View File

@@ -99,6 +99,7 @@ type body struct {
Exp int64 `json:"exp"`
// issued at time unix_ts
Iat int64 `json:"iat"`
Iss string `json:"iss"`
Jti string `json:"jti,omitempty"`
Scope string `json:"scope,omitempty"`
vmAccessClaim VMAccessClaim
@@ -138,6 +139,14 @@ func (b *body) parse(src string) error {
return fmt.Errorf("cannot parse `iat` field: %w", err)
}
}
if issObject := jv.Get("iss"); issObject != nil {
bIss, err := issObject.StringBytes()
if err != nil {
return fmt.Errorf("cannot parse `iss` field: %w", err)
}
b.Iss = bytesutil.ToUnsafeString(bIss)
}
vaObject := jv.Get("vm_access")
if vaObject == nil {
return ErrVMAccessFieldMissing
@@ -265,6 +274,11 @@ func (t *Token) HasClaims(claims map[string]string) bool {
return true
}
// Issuer returns `iss` claim value from token body
func (t *Token) Issuer() string {
return t.body.Iss
}
// VMAccess return a reference to the VMAccessClaim
// all data are valid until Token is reachable
func (t *Token) VMAccess() *VMAccessClaim {

View File

@@ -305,6 +305,13 @@ func TestParseJWTBody_Failure(t *testing.T) {
`unexpected type for key="logs_extra_stream_filters", got: string, want: array string`,
true,
)
// invalid iss claim value type
f(
`{"iss": {}, "vm_access": {}}`,
"cannot parse `iss` field: value doesn't contain string; it contains object",
true,
)
}
func TestParseJWTBody_Success(t *testing.T) {
@@ -326,6 +333,9 @@ func TestParseJWTBody_Success(t *testing.T) {
if result.Iat != resultExpected.Iat {
t.Fatalf("unexpected Iat; got %d; want %d", result.Iat, resultExpected.Iat)
}
if result.Iss != resultExpected.Iss {
t.Fatalf("unexpected Iss; got %s; want %s", result.Iss, resultExpected.Iss)
}
if result.Scope != resultExpected.Scope {
t.Fatalf("unexpected scope; got %q; want %q", result.Scope, resultExpected.Scope)
}
@@ -349,6 +359,10 @@ func TestParseJWTBody_Success(t *testing.T) {
f(`{"vm_access": {"tenant_id": {}}}`, &body{
vmAccessClaim: VMAccessClaim{},
})
f(`{"iss": "theIssuer", "vm_access": {"tenant_id": {}}}`, &body{
Iss: "theIssuer",
vmAccessClaim: VMAccessClaim{},
})
f(
`