From d89470a33f339dc0d4d7b0c41c045c38ba103dd5 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Thu, 15 Oct 2020 12:39:07 +0200 Subject: [PATCH] improve userinfo token handling --- pkg/oidc/token.go | 6 ++++++ pkg/op/userinfo.go | 32 ++++++++++++++++++++++---------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 2a8c0ad..99f18c7 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -24,6 +24,7 @@ type Tokens struct { type AccessTokenClaims interface { Claims + GetSubject() string GetTokenID() string SetPrivateClaims(map[string]interface{}) } @@ -128,6 +129,11 @@ func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgori a.signatureAlg = algorithm } +//GetSubject implements the AccessTokenClaims interface +func (a *accessTokenClaims) GetSubject() string { + return a.Subject +} + //GetTokenID implements the AccessTokenClaims interface func (a *accessTokenClaims) GetTokenID() string { return a.JWTID diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index f1991ac..d5ca68e 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -1,6 +1,7 @@ package op import ( + "context" "errors" "net/http" "strings" @@ -28,17 +29,12 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP http.Error(w, "access token missing", http.StatusUnauthorized) return } - tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) - if err != nil { - accessTokenClaims, err := VerifyAccessToken(r.Context(), accessToken, userinfoProvider.AccessTokenVerifier()) - if err != nil { - http.Error(w, "access token invalid", http.StatusUnauthorized) - return - } - tokenID = accessTokenClaims.GetTokenID() + tokenID, subject, ok := getTokenIDAndSubject(r.Context(), userinfoProvider, accessToken) + if !ok { + http.Error(w, "access token invalid", http.StatusUnauthorized) + return } - splittedToken := strings.Split(tokenIDSubject, ":") - info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), splittedToken[0], splittedToken[1], r.Header.Get("origin")) + info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID, subject, r.Header.Get("origin")) if err != nil { w.WriteHeader(http.StatusForbidden) utils.MarshalJSON(w, err) @@ -67,3 +63,19 @@ func getAccessToken(r *http.Request, decoder utils.Decoder) (string, error) { } return req.AccessToken, nil } + +func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { + tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) + if err == nil { + splittedToken := strings.Split(tokenIDSubject, ":") + if len(splittedToken) != 2 { + return "", "", false + } + return splittedToken[0], splittedToken[1], true + } + accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier()) + if err != nil { + return "", "", false + } + return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true +}