interface

This commit is contained in:
Livio Amstutz 2019-11-28 15:29:19 +01:00
parent 80eeee2de2
commit 988a556fa9
10 changed files with 131 additions and 76 deletions

View file

@ -58,6 +58,24 @@ type AuthRequest struct {
ACRValues []string `schema:"acr_values"`
}
// func (a *AuthRequest) GetID() string {
// return a.ID
// }
// func (a *AuthRequest) GetClientID() string {
// return a.ClientID
// }
func (a *AuthRequest) GetRedirectURI() string {
return a.RedirectURI
}
func (a *AuthRequest) GetResponseType() ResponseType {
return a.ResponseType
}
func (a *AuthRequest) GetState() string {
return a.State
}
type TokenRequest interface {
// GrantType GrantType `schema:"grant_type"`
GrantType() GrantType

View file

@ -1,7 +1,12 @@
package oidc
import (
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"fmt"
"hash"
"time"
"golang.org/x/oauth2"
@ -82,3 +87,27 @@ type Tokens struct {
*oauth2.Token
IDTokenClaims *IDTokenClaims
}
func AccessTokenHash(accessToken string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
tokenHash, err := getHashAlgorithm(sigAlgorithm)
if err != nil {
return "", err
}
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2]
return base64.RawURLEncoding.EncodeToString(sum), nil
}
func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
switch sigAlgorithm {
case jose.RS256, jose.ES256, jose.PS256:
return sha256.New(), nil
case jose.RS384, jose.ES384, jose.PS384:
return sha512.New384(), nil
case jose.RS512, jose.ES512, jose.PS512:
return sha512.New(), nil
default:
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
}
}

View file

@ -5,6 +5,7 @@ import (
"net/http"
"net/url"
"strings"
"time"
"github.com/gorilla/mux"
"github.com/gorilla/schema"
@ -19,7 +20,7 @@ type Authorizer interface {
Decoder() *schema.Decoder
Encoder() *schema.Encoder
Signe() u.Signer
ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
// ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
}
// type Signer interface {
@ -32,7 +33,7 @@ type ValidationAuthorizer interface {
}
// type errorHandler func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
type callbackHandler func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request)
// type callbackHandler func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request)
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
err := r.ParseForm()
@ -58,18 +59,18 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
return
}
err = authorizer.Storage().CreateAuthRequest(authReq)
req, err := authorizer.Storage().CreateAuthRequest(authReq)
if err != nil {
AuthRequestError(w, r, authReq, err)
return
}
client, err := authorizer.Storage().GetClientByClientID(authReq.ClientID)
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
if err != nil {
AuthRequestError(w, r, authReq, err)
AuthRequestError(w, r, req, err)
return
}
RedirectToLogin(authReq, client, w, r)
RedirectToLogin(req, client, w, r)
}
func ValidateAuthRequest(authReq *oidc.AuthRequest, storage u.Storage) error {
@ -115,15 +116,15 @@ func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.Respons
return nil
}
if responseType == oidc.ResponseTypeCode {
if strings.HasPrefix(uri, "http://") && oidc.IsConfidentialType(client) {
if strings.HasPrefix(uri, "http://") && u.IsConfidentialType(client) {
return nil
}
if client.ApplicationType() == oidc.ApplicationTypeNative {
if client.ApplicationType() == u.ApplicationTypeNative {
return nil
}
return ErrInvalidRequest("redirect_uri not allowed 2")
} else {
if client.ApplicationType() != oidc.ApplicationTypeNative {
if client.ApplicationType() != u.ApplicationTypeNative {
return ErrInvalidRequest("redirect_uri not allowed 3")
}
if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) {
@ -133,8 +134,8 @@ func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.Respons
return nil
}
func RedirectToLogin(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) {
login := client.LoginURL(authReq.ID)
func RedirectToLogin(authReq u.AuthRequest, client u.Client, w http.ResponseWriter, r *http.Request) {
login := client.LoginURL(authReq.GetID())
http.Redirect(w, r, login, http.StatusFound)
}
@ -150,20 +151,20 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
AuthResponse(authReq, authorizer, w, r)
}
func AuthResponse(authReq *oidc.AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
func AuthResponse(authReq u.AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
var callback string
if authReq.ResponseType == oidc.ResponseTypeCode {
callback = fmt.Sprintf("%s?code=%s", authReq.RedirectURI, "test")
if authReq.GetResponseType() == oidc.ResponseTypeCode {
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), "test")
} else {
var accessToken string
var err error
if authReq.ResponseType != oidc.ResponseTypeIDTokenOnly {
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
accessToken, err = CreateAccessToken()
if err != nil {
}
}
idToken, err := CreateIDToken(authReq, accessToken, authorizer.Signe())
idToken, err := CreateIDToken("", authReq, accessToken, time.Now(), time.Now(), "", authorizer.Signe())
if err != nil {
}
@ -175,7 +176,7 @@ func AuthResponse(authReq *oidc.AuthRequest, authorizer Authorizer, w http.Respo
values := make(map[string][]string)
authorizer.Encoder().Encode(resp, values)
v := url.Values(values)
callback = fmt.Sprintf("%s#%s", authReq.RedirectURI, v.Encode())
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), v.Encode())
}
http.Redirect(w, r, callback, http.StatusFound)
}

View file

@ -168,9 +168,9 @@ func (p *DefaultOP) Signe() u.Signer {
// return
}
func (p *DefaultOP) ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
return AuthRequestError
}
// func (p *DefaultOP) ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
// return AuthRequestError
// }
func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
Authorize(w, r, p)

View file

@ -3,6 +3,8 @@ package op
import (
"net/http"
"github.com/caos/oidc/pkg/op/u"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
@ -14,17 +16,17 @@ const (
type errorType string
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq u.ErrAuthRequest, err error) {
if authReq == nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if authReq.RedirectURI == "" {
if authReq.GetRedirectURI() == "" {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
url := authReq.RedirectURI
if authReq.ResponseType == oidc.ResponseTypeCode {
url := authReq.GetRedirectURI()
if authReq.GetResponseType() == oidc.ResponseTypeCode {
url += "?"
} else {
url += "#"
@ -42,8 +44,8 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq *oidc.Auth
if description != "" {
url += "&error_description=" + description
}
if authReq.State != "" {
url += "&state=" + authReq.State
if authReq.GetState() != "" {
url += "&state=" + authReq.GetState()
}
http.Redirect(w, r, url, http.StatusFound)
}
@ -77,17 +79,17 @@ var (
}
)
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest) {
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq u.AuthRequest) {
if authReq == nil {
http.Error(w, e.Error(), http.StatusBadRequest)
return
}
if authReq.RedirectURI == "" {
if authReq.GetRedirectURI() == "" {
http.Error(w, e.Error(), http.StatusBadRequest)
return
}
url := authReq.RedirectURI
if authReq.ResponseType == oidc.ResponseTypeCode {
url := authReq.GetRedirectURI()
if authReq.GetResponseType() == oidc.ResponseTypeCode {
url += "?"
} else {
url += "#"
@ -96,8 +98,8 @@ func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request,
if e.Description != "" {
url += "&error_description=" + e.Description
}
if authReq.State != "" {
url += "&state=" + authReq.State
if authReq.GetState() != "" {
url += "&state=" + authReq.GetState()
}
http.Redirect(w, r, url, http.StatusFound)
}

View file

@ -25,4 +25,5 @@ require (
github.com/gorilla/schema v1.1.0
github.com/stretchr/testify v1.4.0
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
gopkg.in/square/go-jose.v2 v2.4.0
)

View file

@ -5,6 +5,8 @@ import (
"net/http"
"time"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/op/u"
"github.com/caos/oidc/pkg/utils"
@ -52,7 +54,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, storage u.Storage, dec
ExchangeRequestError(w, r, err)
return
}
err = storage.DeleteAuthRequestAndCode(authReq.ID, tokenReq.Code)
err = storage.DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code)
if err != nil {
ExchangeRequestError(w, r, err)
return
@ -62,7 +64,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, storage u.Storage, dec
ExchangeRequestError(w, r, err)
return
}
idToken, err := CreateIDToken(nil, "", nil)
idToken, err := CreateIDToken("", authReq, "", time.Now(), time.Now(), "", nil)
if err != nil {
ExchangeRequestError(w, r, err)
return
@ -79,28 +81,31 @@ func CreateAccessToken() (string, error) {
return "accessToken", nil
}
func CreateIDToken(authReq *oidc.AuthRequest, atHash string, signer u.Signer) (string, error) {
var issuer, sub, acr string
var aud, amr []string
var exp, iat, authTime time.Time
func CreateIDToken(issuer string, authReq u.AuthRequest, sub string, exp, authTime time.Time, accessToken string, signer u.Signer) (string, error) {
var err error
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Subject: sub,
Audiences: aud,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: iat,
IssuedAt: time.Now().UTC(),
AuthTime: authTime,
Nonce: authReq.Nonce,
AuthenticationContextClassReference: acr,
AuthenticationMethodsReferences: amr,
AuthorizedParty: authReq.ClientID,
AccessTokenHash: atHash,
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
}
if accessToken != "" {
var alg jose.SignatureAlgorithm
claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, alg) //TODO: alg
if err != nil {
return "", err
}
}
return signer.Sign(claims)
}
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, storage u.Storage) (oidc.Client, error) {
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, storage u.Storage) (u.Client, error) {
if tokenReq.ClientID == "" {
clientID, clientSecret, ok := r.BasicAuth()
if ok {

View file

@ -1,4 +1,4 @@
package oidc
package u
type Client interface {
RedirectURIs() []string

View file

@ -3,11 +3,30 @@ package u
import "github.com/caos/oidc/pkg/oidc"
type Storage interface {
CreateAuthRequest(*oidc.AuthRequest) error
GetClientByClientID(string) (oidc.Client, error)
AuthRequestByID(string) (*oidc.AuthRequest, error)
AuthRequestByCode(oidc.Client, string, string) (*oidc.AuthRequest, error)
AuthorizeClientIDSecret(string, string) (oidc.Client, error)
AuthorizeClientIDCodeVerifier(string, string) (oidc.Client, error)
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
GetClientByClientID(string) (Client, error)
AuthRequestByID(string) (AuthRequest, error)
AuthRequestByCode(Client, string, string) (AuthRequest, error)
AuthorizeClientIDSecret(string, string) (Client, error)
AuthorizeClientIDCodeVerifier(string, string) (Client, error)
DeleteAuthRequestAndCode(string, string) error
}
type ErrAuthRequest interface {
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetState() string
}
type AuthRequest interface {
GetID() string
GetACR() string
GetAMR() []string
GetAudience() []string
GetClientID() string
GetNonce() string
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetState() string
GetSubject() string
}

View file

@ -3,12 +3,9 @@ package rp
import (
"bytes"
"context"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"fmt"
"hash"
"strings"
"time"
@ -446,29 +443,12 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor
return nil //TODO: return error
}
tokenHash, err := getHashAlgorithm(sigAlgorithm)
actual, err := oidc.AccessTokenHash(accessToken, sigAlgorithm)
if err != nil {
return err
}
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2]
actual := base64.RawURLEncoding.EncodeToString(sum)
if actual != atHash {
return nil //TODO: error
}
return nil
}
func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
switch sigAlgorithm {
case jose.RS256, jose.ES256, jose.PS256:
return sha256.New(), nil
case jose.RS384, jose.ES384, jose.PS384:
return sha512.New384(), nil
case jose.RS512, jose.ES512, jose.PS512:
return sha512.New(), nil
default:
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
}
}