interface
This commit is contained in:
parent
80eeee2de2
commit
988a556fa9
10 changed files with 131 additions and 76 deletions
|
@ -58,6 +58,24 @@ type AuthRequest struct {
|
|||
ACRValues []string `schema:"acr_values"`
|
||||
}
|
||||
|
||||
// func (a *AuthRequest) GetID() string {
|
||||
// return a.ID
|
||||
// }
|
||||
|
||||
// func (a *AuthRequest) GetClientID() string {
|
||||
// return a.ClientID
|
||||
// }
|
||||
|
||||
func (a *AuthRequest) GetRedirectURI() string {
|
||||
return a.RedirectURI
|
||||
}
|
||||
func (a *AuthRequest) GetResponseType() ResponseType {
|
||||
return a.ResponseType
|
||||
}
|
||||
func (a *AuthRequest) GetState() string {
|
||||
return a.State
|
||||
}
|
||||
|
||||
type TokenRequest interface {
|
||||
// GrantType GrantType `schema:"grant_type"`
|
||||
GrantType() GrantType
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -82,3 +87,27 @@ type Tokens struct {
|
|||
*oauth2.Token
|
||||
IDTokenClaims *IDTokenClaims
|
||||
}
|
||||
|
||||
func AccessTokenHash(accessToken string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||
tokenHash, err := getHashAlgorithm(sigAlgorithm)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error
|
||||
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2]
|
||||
return base64.RawURLEncoding.EncodeToString(sum), nil
|
||||
}
|
||||
|
||||
func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
|
||||
switch sigAlgorithm {
|
||||
case jose.RS256, jose.ES256, jose.PS256:
|
||||
return sha256.New(), nil
|
||||
case jose.RS384, jose.ES384, jose.PS384:
|
||||
return sha512.New384(), nil
|
||||
case jose.RS512, jose.ES512, jose.PS512:
|
||||
return sha512.New(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/schema"
|
||||
|
@ -19,7 +20,7 @@ type Authorizer interface {
|
|||
Decoder() *schema.Decoder
|
||||
Encoder() *schema.Encoder
|
||||
Signe() u.Signer
|
||||
ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
|
||||
// ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
|
||||
}
|
||||
|
||||
// type Signer interface {
|
||||
|
@ -32,7 +33,7 @@ type ValidationAuthorizer interface {
|
|||
}
|
||||
|
||||
// type errorHandler func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
|
||||
type callbackHandler func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request)
|
||||
// type callbackHandler func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request)
|
||||
|
||||
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
err := r.ParseForm()
|
||||
|
@ -58,18 +59,18 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
return
|
||||
}
|
||||
|
||||
err = authorizer.Storage().CreateAuthRequest(authReq)
|
||||
req, err := authorizer.Storage().CreateAuthRequest(authReq)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := authorizer.Storage().GetClientByClientID(authReq.ClientID)
|
||||
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err)
|
||||
AuthRequestError(w, r, req, err)
|
||||
return
|
||||
}
|
||||
RedirectToLogin(authReq, client, w, r)
|
||||
RedirectToLogin(req, client, w, r)
|
||||
}
|
||||
|
||||
func ValidateAuthRequest(authReq *oidc.AuthRequest, storage u.Storage) error {
|
||||
|
@ -115,15 +116,15 @@ func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.Respons
|
|||
return nil
|
||||
}
|
||||
if responseType == oidc.ResponseTypeCode {
|
||||
if strings.HasPrefix(uri, "http://") && oidc.IsConfidentialType(client) {
|
||||
if strings.HasPrefix(uri, "http://") && u.IsConfidentialType(client) {
|
||||
return nil
|
||||
}
|
||||
if client.ApplicationType() == oidc.ApplicationTypeNative {
|
||||
if client.ApplicationType() == u.ApplicationTypeNative {
|
||||
return nil
|
||||
}
|
||||
return ErrInvalidRequest("redirect_uri not allowed 2")
|
||||
} else {
|
||||
if client.ApplicationType() != oidc.ApplicationTypeNative {
|
||||
if client.ApplicationType() != u.ApplicationTypeNative {
|
||||
return ErrInvalidRequest("redirect_uri not allowed 3")
|
||||
}
|
||||
if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) {
|
||||
|
@ -133,8 +134,8 @@ func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.Respons
|
|||
return nil
|
||||
}
|
||||
|
||||
func RedirectToLogin(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) {
|
||||
login := client.LoginURL(authReq.ID)
|
||||
func RedirectToLogin(authReq u.AuthRequest, client u.Client, w http.ResponseWriter, r *http.Request) {
|
||||
login := client.LoginURL(authReq.GetID())
|
||||
http.Redirect(w, r, login, http.StatusFound)
|
||||
}
|
||||
|
||||
|
@ -150,20 +151,20 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
|
|||
AuthResponse(authReq, authorizer, w, r)
|
||||
}
|
||||
|
||||
func AuthResponse(authReq *oidc.AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||
func AuthResponse(authReq u.AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||
var callback string
|
||||
if authReq.ResponseType == oidc.ResponseTypeCode {
|
||||
callback = fmt.Sprintf("%s?code=%s", authReq.RedirectURI, "test")
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), "test")
|
||||
} else {
|
||||
var accessToken string
|
||||
var err error
|
||||
if authReq.ResponseType != oidc.ResponseTypeIDTokenOnly {
|
||||
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
|
||||
accessToken, err = CreateAccessToken()
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
}
|
||||
idToken, err := CreateIDToken(authReq, accessToken, authorizer.Signe())
|
||||
idToken, err := CreateIDToken("", authReq, accessToken, time.Now(), time.Now(), "", authorizer.Signe())
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
|
@ -175,7 +176,7 @@ func AuthResponse(authReq *oidc.AuthRequest, authorizer Authorizer, w http.Respo
|
|||
values := make(map[string][]string)
|
||||
authorizer.Encoder().Encode(resp, values)
|
||||
v := url.Values(values)
|
||||
callback = fmt.Sprintf("%s#%s", authReq.RedirectURI, v.Encode())
|
||||
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), v.Encode())
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
|
|
@ -168,9 +168,9 @@ func (p *DefaultOP) Signe() u.Signer {
|
|||
// return
|
||||
}
|
||||
|
||||
func (p *DefaultOP) ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
|
||||
return AuthRequestError
|
||||
}
|
||||
// func (p *DefaultOP) ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
|
||||
// return AuthRequestError
|
||||
// }
|
||||
|
||||
func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
|
||||
Authorize(w, r, p)
|
||||
|
|
|
@ -3,6 +3,8 @@ package op
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/caos/oidc/pkg/op/u"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
@ -14,17 +16,17 @@ const (
|
|||
|
||||
type errorType string
|
||||
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq u.ErrAuthRequest, err error) {
|
||||
if authReq == nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if authReq.RedirectURI == "" {
|
||||
if authReq.GetRedirectURI() == "" {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
url := authReq.RedirectURI
|
||||
if authReq.ResponseType == oidc.ResponseTypeCode {
|
||||
url := authReq.GetRedirectURI()
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
url += "?"
|
||||
} else {
|
||||
url += "#"
|
||||
|
@ -42,8 +44,8 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq *oidc.Auth
|
|||
if description != "" {
|
||||
url += "&error_description=" + description
|
||||
}
|
||||
if authReq.State != "" {
|
||||
url += "&state=" + authReq.State
|
||||
if authReq.GetState() != "" {
|
||||
url += "&state=" + authReq.GetState()
|
||||
}
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
@ -77,17 +79,17 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest) {
|
||||
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq u.AuthRequest) {
|
||||
if authReq == nil {
|
||||
http.Error(w, e.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if authReq.RedirectURI == "" {
|
||||
if authReq.GetRedirectURI() == "" {
|
||||
http.Error(w, e.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
url := authReq.RedirectURI
|
||||
if authReq.ResponseType == oidc.ResponseTypeCode {
|
||||
url := authReq.GetRedirectURI()
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
url += "?"
|
||||
} else {
|
||||
url += "#"
|
||||
|
@ -96,8 +98,8 @@ func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request,
|
|||
if e.Description != "" {
|
||||
url += "&error_description=" + e.Description
|
||||
}
|
||||
if authReq.State != "" {
|
||||
url += "&state=" + authReq.State
|
||||
if authReq.GetState() != "" {
|
||||
url += "&state=" + authReq.GetState()
|
||||
}
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
|
|
@ -25,4 +25,5 @@ require (
|
|||
github.com/gorilla/schema v1.1.0
|
||||
github.com/stretchr/testify v1.4.0
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
|
||||
gopkg.in/square/go-jose.v2 v2.4.0
|
||||
)
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/op/u"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
|
||||
|
@ -52,7 +54,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, storage u.Storage, dec
|
|||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
err = storage.DeleteAuthRequestAndCode(authReq.ID, tokenReq.Code)
|
||||
err = storage.DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
|
@ -62,7 +64,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, storage u.Storage, dec
|
|||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
idToken, err := CreateIDToken(nil, "", nil)
|
||||
idToken, err := CreateIDToken("", authReq, "", time.Now(), time.Now(), "", nil)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
|
@ -79,28 +81,31 @@ func CreateAccessToken() (string, error) {
|
|||
return "accessToken", nil
|
||||
}
|
||||
|
||||
func CreateIDToken(authReq *oidc.AuthRequest, atHash string, signer u.Signer) (string, error) {
|
||||
var issuer, sub, acr string
|
||||
var aud, amr []string
|
||||
var exp, iat, authTime time.Time
|
||||
|
||||
func CreateIDToken(issuer string, authReq u.AuthRequest, sub string, exp, authTime time.Time, accessToken string, signer u.Signer) (string, error) {
|
||||
var err error
|
||||
claims := &oidc.IDTokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: sub,
|
||||
Audiences: aud,
|
||||
Subject: authReq.GetSubject(),
|
||||
Audiences: authReq.GetAudience(),
|
||||
Expiration: exp,
|
||||
IssuedAt: iat,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
AuthTime: authTime,
|
||||
Nonce: authReq.Nonce,
|
||||
AuthenticationContextClassReference: acr,
|
||||
AuthenticationMethodsReferences: amr,
|
||||
AuthorizedParty: authReq.ClientID,
|
||||
AccessTokenHash: atHash,
|
||||
Nonce: authReq.GetNonce(),
|
||||
AuthenticationContextClassReference: authReq.GetACR(),
|
||||
AuthenticationMethodsReferences: authReq.GetAMR(),
|
||||
AuthorizedParty: authReq.GetClientID(),
|
||||
}
|
||||
if accessToken != "" {
|
||||
var alg jose.SignatureAlgorithm
|
||||
claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, alg) //TODO: alg
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return signer.Sign(claims)
|
||||
}
|
||||
|
||||
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, storage u.Storage) (oidc.Client, error) {
|
||||
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, storage u.Storage) (u.Client, error) {
|
||||
if tokenReq.ClientID == "" {
|
||||
clientID, clientSecret, ok := r.BasicAuth()
|
||||
if ok {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package oidc
|
||||
package u
|
||||
|
||||
type Client interface {
|
||||
RedirectURIs() []string
|
|
@ -3,11 +3,30 @@ package u
|
|||
import "github.com/caos/oidc/pkg/oidc"
|
||||
|
||||
type Storage interface {
|
||||
CreateAuthRequest(*oidc.AuthRequest) error
|
||||
GetClientByClientID(string) (oidc.Client, error)
|
||||
AuthRequestByID(string) (*oidc.AuthRequest, error)
|
||||
AuthRequestByCode(oidc.Client, string, string) (*oidc.AuthRequest, error)
|
||||
AuthorizeClientIDSecret(string, string) (oidc.Client, error)
|
||||
AuthorizeClientIDCodeVerifier(string, string) (oidc.Client, error)
|
||||
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
|
||||
GetClientByClientID(string) (Client, error)
|
||||
AuthRequestByID(string) (AuthRequest, error)
|
||||
AuthRequestByCode(Client, string, string) (AuthRequest, error)
|
||||
AuthorizeClientIDSecret(string, string) (Client, error)
|
||||
AuthorizeClientIDCodeVerifier(string, string) (Client, error)
|
||||
DeleteAuthRequestAndCode(string, string) error
|
||||
}
|
||||
|
||||
type ErrAuthRequest interface {
|
||||
GetRedirectURI() string
|
||||
GetResponseType() oidc.ResponseType
|
||||
GetState() string
|
||||
}
|
||||
|
||||
type AuthRequest interface {
|
||||
GetID() string
|
||||
GetACR() string
|
||||
GetAMR() []string
|
||||
GetAudience() []string
|
||||
GetClientID() string
|
||||
GetNonce() string
|
||||
GetRedirectURI() string
|
||||
GetResponseType() oidc.ResponseType
|
||||
GetState() string
|
||||
GetSubject() string
|
||||
}
|
||||
|
|
|
@ -3,12 +3,9 @@ package rp
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -446,29 +443,12 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor
|
|||
return nil //TODO: return error
|
||||
}
|
||||
|
||||
tokenHash, err := getHashAlgorithm(sigAlgorithm)
|
||||
actual, err := oidc.AccessTokenHash(accessToken, sigAlgorithm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error
|
||||
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2]
|
||||
actual := base64.RawURLEncoding.EncodeToString(sum)
|
||||
if actual != atHash {
|
||||
return nil //TODO: error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
|
||||
switch sigAlgorithm {
|
||||
case jose.RS256, jose.ES256, jose.PS256:
|
||||
return sha256.New(), nil
|
||||
case jose.RS384, jose.ES384, jose.PS384:
|
||||
return sha512.New384(), nil
|
||||
case jose.RS512, jose.ES512, jose.PS512:
|
||||
return sha512.New(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue