interface

This commit is contained in:
Livio Amstutz 2019-11-28 15:29:19 +01:00
parent 80eeee2de2
commit 988a556fa9
10 changed files with 131 additions and 76 deletions

View file

@ -58,6 +58,24 @@ type AuthRequest struct {
ACRValues []string `schema:"acr_values"` ACRValues []string `schema:"acr_values"`
} }
// func (a *AuthRequest) GetID() string {
// return a.ID
// }
// func (a *AuthRequest) GetClientID() string {
// return a.ClientID
// }
func (a *AuthRequest) GetRedirectURI() string {
return a.RedirectURI
}
func (a *AuthRequest) GetResponseType() ResponseType {
return a.ResponseType
}
func (a *AuthRequest) GetState() string {
return a.State
}
type TokenRequest interface { type TokenRequest interface {
// GrantType GrantType `schema:"grant_type"` // GrantType GrantType `schema:"grant_type"`
GrantType() GrantType GrantType() GrantType

View file

@ -1,7 +1,12 @@
package oidc package oidc
import ( import (
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt"
"hash"
"time" "time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -82,3 +87,27 @@ type Tokens struct {
*oauth2.Token *oauth2.Token
IDTokenClaims *IDTokenClaims IDTokenClaims *IDTokenClaims
} }
func AccessTokenHash(accessToken string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
tokenHash, err := getHashAlgorithm(sigAlgorithm)
if err != nil {
return "", err
}
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2]
return base64.RawURLEncoding.EncodeToString(sum), nil
}
func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
switch sigAlgorithm {
case jose.RS256, jose.ES256, jose.PS256:
return sha256.New(), nil
case jose.RS384, jose.ES384, jose.PS384:
return sha512.New384(), nil
case jose.RS512, jose.ES512, jose.PS512:
return sha512.New(), nil
default:
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
}
}

View file

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/schema" "github.com/gorilla/schema"
@ -19,7 +20,7 @@ type Authorizer interface {
Decoder() *schema.Decoder Decoder() *schema.Decoder
Encoder() *schema.Encoder Encoder() *schema.Encoder
Signe() u.Signer Signe() u.Signer
ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) // ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
} }
// type Signer interface { // type Signer interface {
@ -32,7 +33,7 @@ type ValidationAuthorizer interface {
} }
// type errorHandler func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) // type errorHandler func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
type callbackHandler func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) // type callbackHandler func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request)
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
err := r.ParseForm() err := r.ParseForm()
@ -58,18 +59,18 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
return return
} }
err = authorizer.Storage().CreateAuthRequest(authReq) req, err := authorizer.Storage().CreateAuthRequest(authReq)
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err) AuthRequestError(w, r, authReq, err)
return return
} }
client, err := authorizer.Storage().GetClientByClientID(authReq.ClientID) client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err) AuthRequestError(w, r, req, err)
return return
} }
RedirectToLogin(authReq, client, w, r) RedirectToLogin(req, client, w, r)
} }
func ValidateAuthRequest(authReq *oidc.AuthRequest, storage u.Storage) error { func ValidateAuthRequest(authReq *oidc.AuthRequest, storage u.Storage) error {
@ -115,15 +116,15 @@ func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.Respons
return nil return nil
} }
if responseType == oidc.ResponseTypeCode { if responseType == oidc.ResponseTypeCode {
if strings.HasPrefix(uri, "http://") && oidc.IsConfidentialType(client) { if strings.HasPrefix(uri, "http://") && u.IsConfidentialType(client) {
return nil return nil
} }
if client.ApplicationType() == oidc.ApplicationTypeNative { if client.ApplicationType() == u.ApplicationTypeNative {
return nil return nil
} }
return ErrInvalidRequest("redirect_uri not allowed 2") return ErrInvalidRequest("redirect_uri not allowed 2")
} else { } else {
if client.ApplicationType() != oidc.ApplicationTypeNative { if client.ApplicationType() != u.ApplicationTypeNative {
return ErrInvalidRequest("redirect_uri not allowed 3") return ErrInvalidRequest("redirect_uri not allowed 3")
} }
if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) { if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) {
@ -133,8 +134,8 @@ func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.Respons
return nil return nil
} }
func RedirectToLogin(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { func RedirectToLogin(authReq u.AuthRequest, client u.Client, w http.ResponseWriter, r *http.Request) {
login := client.LoginURL(authReq.ID) login := client.LoginURL(authReq.GetID())
http.Redirect(w, r, login, http.StatusFound) http.Redirect(w, r, login, http.StatusFound)
} }
@ -150,20 +151,20 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
AuthResponse(authReq, authorizer, w, r) AuthResponse(authReq, authorizer, w, r)
} }
func AuthResponse(authReq *oidc.AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { func AuthResponse(authReq u.AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
var callback string var callback string
if authReq.ResponseType == oidc.ResponseTypeCode { if authReq.GetResponseType() == oidc.ResponseTypeCode {
callback = fmt.Sprintf("%s?code=%s", authReq.RedirectURI, "test") callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), "test")
} else { } else {
var accessToken string var accessToken string
var err error var err error
if authReq.ResponseType != oidc.ResponseTypeIDTokenOnly { if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
accessToken, err = CreateAccessToken() accessToken, err = CreateAccessToken()
if err != nil { if err != nil {
} }
} }
idToken, err := CreateIDToken(authReq, accessToken, authorizer.Signe()) idToken, err := CreateIDToken("", authReq, accessToken, time.Now(), time.Now(), "", authorizer.Signe())
if err != nil { if err != nil {
} }
@ -175,7 +176,7 @@ func AuthResponse(authReq *oidc.AuthRequest, authorizer Authorizer, w http.Respo
values := make(map[string][]string) values := make(map[string][]string)
authorizer.Encoder().Encode(resp, values) authorizer.Encoder().Encode(resp, values)
v := url.Values(values) v := url.Values(values)
callback = fmt.Sprintf("%s#%s", authReq.RedirectURI, v.Encode()) callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), v.Encode())
} }
http.Redirect(w, r, callback, http.StatusFound) http.Redirect(w, r, callback, http.StatusFound)
} }

View file

@ -168,9 +168,9 @@ func (p *DefaultOP) Signe() u.Signer {
// return // return
} }
func (p *DefaultOP) ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { // func (p *DefaultOP) ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
return AuthRequestError // return AuthRequestError
} // }
func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) { func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
Authorize(w, r, p) Authorize(w, r, p)

View file

@ -3,6 +3,8 @@ package op
import ( import (
"net/http" "net/http"
"github.com/caos/oidc/pkg/op/u"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
@ -14,17 +16,17 @@ const (
type errorType string type errorType string
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq u.ErrAuthRequest, err error) {
if authReq == nil { if authReq == nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
if authReq.RedirectURI == "" { if authReq.GetRedirectURI() == "" {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
url := authReq.RedirectURI url := authReq.GetRedirectURI()
if authReq.ResponseType == oidc.ResponseTypeCode { if authReq.GetResponseType() == oidc.ResponseTypeCode {
url += "?" url += "?"
} else { } else {
url += "#" url += "#"
@ -42,8 +44,8 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq *oidc.Auth
if description != "" { if description != "" {
url += "&error_description=" + description url += "&error_description=" + description
} }
if authReq.State != "" { if authReq.GetState() != "" {
url += "&state=" + authReq.State url += "&state=" + authReq.GetState()
} }
http.Redirect(w, r, url, http.StatusFound) http.Redirect(w, r, url, http.StatusFound)
} }
@ -77,17 +79,17 @@ var (
} }
) )
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest) { func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq u.AuthRequest) {
if authReq == nil { if authReq == nil {
http.Error(w, e.Error(), http.StatusBadRequest) http.Error(w, e.Error(), http.StatusBadRequest)
return return
} }
if authReq.RedirectURI == "" { if authReq.GetRedirectURI() == "" {
http.Error(w, e.Error(), http.StatusBadRequest) http.Error(w, e.Error(), http.StatusBadRequest)
return return
} }
url := authReq.RedirectURI url := authReq.GetRedirectURI()
if authReq.ResponseType == oidc.ResponseTypeCode { if authReq.GetResponseType() == oidc.ResponseTypeCode {
url += "?" url += "?"
} else { } else {
url += "#" url += "#"
@ -96,8 +98,8 @@ func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request,
if e.Description != "" { if e.Description != "" {
url += "&error_description=" + e.Description url += "&error_description=" + e.Description
} }
if authReq.State != "" { if authReq.GetState() != "" {
url += "&state=" + authReq.State url += "&state=" + authReq.GetState()
} }
http.Redirect(w, r, url, http.StatusFound) http.Redirect(w, r, url, http.StatusFound)
} }

View file

@ -25,4 +25,5 @@ require (
github.com/gorilla/schema v1.1.0 github.com/gorilla/schema v1.1.0
github.com/stretchr/testify v1.4.0 github.com/stretchr/testify v1.4.0
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
gopkg.in/square/go-jose.v2 v2.4.0
) )

View file

@ -5,6 +5,8 @@ import (
"net/http" "net/http"
"time" "time"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/op/u" "github.com/caos/oidc/pkg/op/u"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
@ -52,7 +54,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, storage u.Storage, dec
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
} }
err = storage.DeleteAuthRequestAndCode(authReq.ID, tokenReq.Code) err = storage.DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code)
if err != nil { if err != nil {
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
@ -62,7 +64,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, storage u.Storage, dec
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
} }
idToken, err := CreateIDToken(nil, "", nil) idToken, err := CreateIDToken("", authReq, "", time.Now(), time.Now(), "", nil)
if err != nil { if err != nil {
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
@ -79,28 +81,31 @@ func CreateAccessToken() (string, error) {
return "accessToken", nil return "accessToken", nil
} }
func CreateIDToken(authReq *oidc.AuthRequest, atHash string, signer u.Signer) (string, error) { func CreateIDToken(issuer string, authReq u.AuthRequest, sub string, exp, authTime time.Time, accessToken string, signer u.Signer) (string, error) {
var issuer, sub, acr string var err error
var aud, amr []string
var exp, iat, authTime time.Time
claims := &oidc.IDTokenClaims{ claims := &oidc.IDTokenClaims{
Issuer: issuer, Issuer: issuer,
Subject: sub, Subject: authReq.GetSubject(),
Audiences: aud, Audiences: authReq.GetAudience(),
Expiration: exp, Expiration: exp,
IssuedAt: iat, IssuedAt: time.Now().UTC(),
AuthTime: authTime, AuthTime: authTime,
Nonce: authReq.Nonce, Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: acr, AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: amr, AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.ClientID, AuthorizedParty: authReq.GetClientID(),
AccessTokenHash: atHash, }
if accessToken != "" {
var alg jose.SignatureAlgorithm
claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, alg) //TODO: alg
if err != nil {
return "", err
}
} }
return signer.Sign(claims) return signer.Sign(claims)
} }
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, storage u.Storage) (oidc.Client, error) { func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, storage u.Storage) (u.Client, error) {
if tokenReq.ClientID == "" { if tokenReq.ClientID == "" {
clientID, clientSecret, ok := r.BasicAuth() clientID, clientSecret, ok := r.BasicAuth()
if ok { if ok {

View file

@ -1,4 +1,4 @@
package oidc package u
type Client interface { type Client interface {
RedirectURIs() []string RedirectURIs() []string

View file

@ -3,11 +3,30 @@ package u
import "github.com/caos/oidc/pkg/oidc" import "github.com/caos/oidc/pkg/oidc"
type Storage interface { type Storage interface {
CreateAuthRequest(*oidc.AuthRequest) error CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
GetClientByClientID(string) (oidc.Client, error) GetClientByClientID(string) (Client, error)
AuthRequestByID(string) (*oidc.AuthRequest, error) AuthRequestByID(string) (AuthRequest, error)
AuthRequestByCode(oidc.Client, string, string) (*oidc.AuthRequest, error) AuthRequestByCode(Client, string, string) (AuthRequest, error)
AuthorizeClientIDSecret(string, string) (oidc.Client, error) AuthorizeClientIDSecret(string, string) (Client, error)
AuthorizeClientIDCodeVerifier(string, string) (oidc.Client, error) AuthorizeClientIDCodeVerifier(string, string) (Client, error)
DeleteAuthRequestAndCode(string, string) error DeleteAuthRequestAndCode(string, string) error
} }
type ErrAuthRequest interface {
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetState() string
}
type AuthRequest interface {
GetID() string
GetACR() string
GetAMR() []string
GetAudience() []string
GetClientID() string
GetNonce() string
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetState() string
GetSubject() string
}

View file

@ -3,12 +3,9 @@ package rp
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/sha256"
"crypto/sha512"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"hash"
"strings" "strings"
"time" "time"
@ -446,29 +443,12 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor
return nil //TODO: return error return nil //TODO: return error
} }
tokenHash, err := getHashAlgorithm(sigAlgorithm) actual, err := oidc.AccessTokenHash(accessToken, sigAlgorithm)
if err != nil { if err != nil {
return err return err
} }
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2]
actual := base64.RawURLEncoding.EncodeToString(sum)
if actual != atHash { if actual != atHash {
return nil //TODO: error return nil //TODO: error
} }
return nil return nil
} }
func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
switch sigAlgorithm {
case jose.RS256, jose.ES256, jose.PS256:
return sha256.New(), nil
case jose.RS384, jose.ES384, jose.PS384:
return sha512.New384(), nil
case jose.RS512, jose.ES512, jose.PS512:
return sha512.New(), nil
default:
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
}
}