diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index d4927ab..d3b8505 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -58,6 +58,24 @@ type AuthRequest struct { ACRValues []string `schema:"acr_values"` } +// func (a *AuthRequest) GetID() string { +// return a.ID +// } + +// func (a *AuthRequest) GetClientID() string { +// return a.ClientID +// } + +func (a *AuthRequest) GetRedirectURI() string { + return a.RedirectURI +} +func (a *AuthRequest) GetResponseType() ResponseType { + return a.ResponseType +} +func (a *AuthRequest) GetState() string { + return a.State +} + type TokenRequest interface { // GrantType GrantType `schema:"grant_type"` GrantType() GrantType diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 5e2b659..4a71847 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -1,7 +1,12 @@ package oidc import ( + "crypto/sha256" + "crypto/sha512" + "encoding/base64" "encoding/json" + "fmt" + "hash" "time" "golang.org/x/oauth2" @@ -82,3 +87,27 @@ type Tokens struct { *oauth2.Token IDTokenClaims *IDTokenClaims } + +func AccessTokenHash(accessToken string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { + tokenHash, err := getHashAlgorithm(sigAlgorithm) + if err != nil { + return "", err + } + + tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error + sum := tokenHash.Sum(nil)[:tokenHash.Size()/2] + return base64.RawURLEncoding.EncodeToString(sum), nil +} + +func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) { + switch sigAlgorithm { + case jose.RS256, jose.ES256, jose.PS256: + return sha256.New(), nil + case jose.RS384, jose.ES384, jose.PS384: + return sha512.New384(), nil + case jose.RS512, jose.ES512, jose.PS512: + return sha512.New(), nil + default: + return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm) + } +} diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index e580669..5c5ed30 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -5,6 +5,7 @@ import ( "net/http" "net/url" "strings" + "time" "github.com/gorilla/mux" "github.com/gorilla/schema" @@ -19,7 +20,7 @@ type Authorizer interface { Decoder() *schema.Decoder Encoder() *schema.Encoder Signe() u.Signer - ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) + // ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) } // type Signer interface { @@ -32,7 +33,7 @@ type ValidationAuthorizer interface { } // type errorHandler func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) -type callbackHandler func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) +// type callbackHandler func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { err := r.ParseForm() @@ -58,18 +59,18 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { return } - err = authorizer.Storage().CreateAuthRequest(authReq) + req, err := authorizer.Storage().CreateAuthRequest(authReq) if err != nil { AuthRequestError(w, r, authReq, err) return } - client, err := authorizer.Storage().GetClientByClientID(authReq.ClientID) + client, err := authorizer.Storage().GetClientByClientID(req.GetClientID()) if err != nil { - AuthRequestError(w, r, authReq, err) + AuthRequestError(w, r, req, err) return } - RedirectToLogin(authReq, client, w, r) + RedirectToLogin(req, client, w, r) } func ValidateAuthRequest(authReq *oidc.AuthRequest, storage u.Storage) error { @@ -115,15 +116,15 @@ func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.Respons return nil } if responseType == oidc.ResponseTypeCode { - if strings.HasPrefix(uri, "http://") && oidc.IsConfidentialType(client) { + if strings.HasPrefix(uri, "http://") && u.IsConfidentialType(client) { return nil } - if client.ApplicationType() == oidc.ApplicationTypeNative { + if client.ApplicationType() == u.ApplicationTypeNative { return nil } return ErrInvalidRequest("redirect_uri not allowed 2") } else { - if client.ApplicationType() != oidc.ApplicationTypeNative { + if client.ApplicationType() != u.ApplicationTypeNative { return ErrInvalidRequest("redirect_uri not allowed 3") } if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) { @@ -133,8 +134,8 @@ func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.Respons return nil } -func RedirectToLogin(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { - login := client.LoginURL(authReq.ID) +func RedirectToLogin(authReq u.AuthRequest, client u.Client, w http.ResponseWriter, r *http.Request) { + login := client.LoginURL(authReq.GetID()) http.Redirect(w, r, login, http.StatusFound) } @@ -150,20 +151,20 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author AuthResponse(authReq, authorizer, w, r) } -func AuthResponse(authReq *oidc.AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { +func AuthResponse(authReq u.AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { var callback string - if authReq.ResponseType == oidc.ResponseTypeCode { - callback = fmt.Sprintf("%s?code=%s", authReq.RedirectURI, "test") + if authReq.GetResponseType() == oidc.ResponseTypeCode { + callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), "test") } else { var accessToken string var err error - if authReq.ResponseType != oidc.ResponseTypeIDTokenOnly { + if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly { accessToken, err = CreateAccessToken() if err != nil { } } - idToken, err := CreateIDToken(authReq, accessToken, authorizer.Signe()) + idToken, err := CreateIDToken("", authReq, accessToken, time.Now(), time.Now(), "", authorizer.Signe()) if err != nil { } @@ -175,7 +176,7 @@ func AuthResponse(authReq *oidc.AuthRequest, authorizer Authorizer, w http.Respo values := make(map[string][]string) authorizer.Encoder().Encode(resp, values) v := url.Values(values) - callback = fmt.Sprintf("%s#%s", authReq.RedirectURI, v.Encode()) + callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), v.Encode()) } http.Redirect(w, r, callback, http.StatusFound) } diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index b3aa54d..2a6d925 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -168,9 +168,9 @@ func (p *DefaultOP) Signe() u.Signer { // return } -func (p *DefaultOP) ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { - return AuthRequestError -} +// func (p *DefaultOP) ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { +// return AuthRequestError +// } func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) { Authorize(w, r, p) diff --git a/pkg/op/error.go b/pkg/op/error.go index d794518..2f7252d 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -3,6 +3,8 @@ package op import ( "net/http" + "github.com/caos/oidc/pkg/op/u" + "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/utils" ) @@ -14,17 +16,17 @@ const ( type errorType string -func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq u.ErrAuthRequest, err error) { if authReq == nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - if authReq.RedirectURI == "" { + if authReq.GetRedirectURI() == "" { http.Error(w, err.Error(), http.StatusBadRequest) return } - url := authReq.RedirectURI - if authReq.ResponseType == oidc.ResponseTypeCode { + url := authReq.GetRedirectURI() + if authReq.GetResponseType() == oidc.ResponseTypeCode { url += "?" } else { url += "#" @@ -42,8 +44,8 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq *oidc.Auth if description != "" { url += "&error_description=" + description } - if authReq.State != "" { - url += "&state=" + authReq.State + if authReq.GetState() != "" { + url += "&state=" + authReq.GetState() } http.Redirect(w, r, url, http.StatusFound) } @@ -77,17 +79,17 @@ var ( } ) -func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest) { +func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq u.AuthRequest) { if authReq == nil { http.Error(w, e.Error(), http.StatusBadRequest) return } - if authReq.RedirectURI == "" { + if authReq.GetRedirectURI() == "" { http.Error(w, e.Error(), http.StatusBadRequest) return } - url := authReq.RedirectURI - if authReq.ResponseType == oidc.ResponseTypeCode { + url := authReq.GetRedirectURI() + if authReq.GetResponseType() == oidc.ResponseTypeCode { url += "?" } else { url += "#" @@ -96,8 +98,8 @@ func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, if e.Description != "" { url += "&error_description=" + e.Description } - if authReq.State != "" { - url += "&state=" + authReq.State + if authReq.GetState() != "" { + url += "&state=" + authReq.GetState() } http.Redirect(w, r, url, http.StatusFound) } diff --git a/pkg/op/go.mod b/pkg/op/go.mod index 418ba38..16a68a7 100644 --- a/pkg/op/go.mod +++ b/pkg/op/go.mod @@ -25,4 +25,5 @@ require ( github.com/gorilla/schema v1.1.0 github.com/stretchr/testify v1.4.0 golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 + gopkg.in/square/go-jose.v2 v2.4.0 ) diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index cd99aab..570a97e 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -5,6 +5,8 @@ import ( "net/http" "time" + "gopkg.in/square/go-jose.v2" + "github.com/caos/oidc/pkg/op/u" "github.com/caos/oidc/pkg/utils" @@ -52,7 +54,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, storage u.Storage, dec ExchangeRequestError(w, r, err) return } - err = storage.DeleteAuthRequestAndCode(authReq.ID, tokenReq.Code) + err = storage.DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code) if err != nil { ExchangeRequestError(w, r, err) return @@ -62,7 +64,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, storage u.Storage, dec ExchangeRequestError(w, r, err) return } - idToken, err := CreateIDToken(nil, "", nil) + idToken, err := CreateIDToken("", authReq, "", time.Now(), time.Now(), "", nil) if err != nil { ExchangeRequestError(w, r, err) return @@ -79,28 +81,31 @@ func CreateAccessToken() (string, error) { return "accessToken", nil } -func CreateIDToken(authReq *oidc.AuthRequest, atHash string, signer u.Signer) (string, error) { - var issuer, sub, acr string - var aud, amr []string - var exp, iat, authTime time.Time - +func CreateIDToken(issuer string, authReq u.AuthRequest, sub string, exp, authTime time.Time, accessToken string, signer u.Signer) (string, error) { + var err error claims := &oidc.IDTokenClaims{ Issuer: issuer, - Subject: sub, - Audiences: aud, + Subject: authReq.GetSubject(), + Audiences: authReq.GetAudience(), Expiration: exp, - IssuedAt: iat, + IssuedAt: time.Now().UTC(), AuthTime: authTime, - Nonce: authReq.Nonce, - AuthenticationContextClassReference: acr, - AuthenticationMethodsReferences: amr, - AuthorizedParty: authReq.ClientID, - AccessTokenHash: atHash, + Nonce: authReq.GetNonce(), + AuthenticationContextClassReference: authReq.GetACR(), + AuthenticationMethodsReferences: authReq.GetAMR(), + AuthorizedParty: authReq.GetClientID(), + } + if accessToken != "" { + var alg jose.SignatureAlgorithm + claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, alg) //TODO: alg + if err != nil { + return "", err + } } return signer.Sign(claims) } -func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, storage u.Storage) (oidc.Client, error) { +func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, storage u.Storage) (u.Client, error) { if tokenReq.ClientID == "" { clientID, clientSecret, ok := r.BasicAuth() if ok { diff --git a/pkg/oidc/client.go b/pkg/op/u/client.go similarity index 97% rename from pkg/oidc/client.go rename to pkg/op/u/client.go index fe243b2..ed37927 100644 --- a/pkg/oidc/client.go +++ b/pkg/op/u/client.go @@ -1,4 +1,4 @@ -package oidc +package u type Client interface { RedirectURIs() []string diff --git a/pkg/op/u/storage.go b/pkg/op/u/storage.go index a446ce6..ed7bfdf 100644 --- a/pkg/op/u/storage.go +++ b/pkg/op/u/storage.go @@ -3,11 +3,30 @@ package u import "github.com/caos/oidc/pkg/oidc" type Storage interface { - CreateAuthRequest(*oidc.AuthRequest) error - GetClientByClientID(string) (oidc.Client, error) - AuthRequestByID(string) (*oidc.AuthRequest, error) - AuthRequestByCode(oidc.Client, string, string) (*oidc.AuthRequest, error) - AuthorizeClientIDSecret(string, string) (oidc.Client, error) - AuthorizeClientIDCodeVerifier(string, string) (oidc.Client, error) + CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error) + GetClientByClientID(string) (Client, error) + AuthRequestByID(string) (AuthRequest, error) + AuthRequestByCode(Client, string, string) (AuthRequest, error) + AuthorizeClientIDSecret(string, string) (Client, error) + AuthorizeClientIDCodeVerifier(string, string) (Client, error) DeleteAuthRequestAndCode(string, string) error } + +type ErrAuthRequest interface { + GetRedirectURI() string + GetResponseType() oidc.ResponseType + GetState() string +} + +type AuthRequest interface { + GetID() string + GetACR() string + GetAMR() []string + GetAudience() []string + GetClientID() string + GetNonce() string + GetRedirectURI() string + GetResponseType() oidc.ResponseType + GetState() string + GetSubject() string +} diff --git a/pkg/rp/default_verifier.go b/pkg/rp/default_verifier.go index 4d429e0..0df4900 100644 --- a/pkg/rp/default_verifier.go +++ b/pkg/rp/default_verifier.go @@ -3,12 +3,9 @@ package rp import ( "bytes" "context" - "crypto/sha256" - "crypto/sha512" "encoding/base64" "encoding/json" "fmt" - "hash" "strings" "time" @@ -446,29 +443,12 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor return nil //TODO: return error } - tokenHash, err := getHashAlgorithm(sigAlgorithm) + actual, err := oidc.AccessTokenHash(accessToken, sigAlgorithm) if err != nil { return err } - - tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error - sum := tokenHash.Sum(nil)[:tokenHash.Size()/2] - actual := base64.RawURLEncoding.EncodeToString(sum) if actual != atHash { return nil //TODO: error } return nil } - -func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) { - switch sigAlgorithm { - case jose.RS256, jose.ES256, jose.PS256: - return sha256.New(), nil - case jose.RS384, jose.ES384, jose.PS384: - return sha512.New384(), nil - case jose.RS512, jose.ES512, jose.PS512: - return sha512.New(), nil - default: - return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm) - } -}