fix: end session (#35)
* fix: handle code separately * fix: option to ignore expiration on id_token and error handling * fix: op handler as http.Handler * fix: terminate session possible wihtout id_token_hint
This commit is contained in:
parent
21dfd6c22e
commit
628bc4ed65
7 changed files with 52 additions and 48 deletions
|
@ -15,20 +15,27 @@ import (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
port := "9998"
|
||||||
config := &op.Config{
|
config := &op.Config{
|
||||||
Issuer: "http://localhost:9998/",
|
Issuer: "http://localhost:9998/",
|
||||||
CryptoKey: sha256.Sum256([]byte("test")),
|
CryptoKey: sha256.Sum256([]byte("test")),
|
||||||
Port: "9998",
|
|
||||||
}
|
}
|
||||||
storage := mock.NewAuthStorage()
|
storage := mock.NewAuthStorage()
|
||||||
handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint(op.NewEndpoint("test")))
|
handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint(op.NewEndpoint("test")))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
router := handler.HttpHandler().Handler.(*mux.Router)
|
router := handler.HttpHandler().(*mux.Router)
|
||||||
router.Methods("GET").Path("/login").HandlerFunc(HandleLogin)
|
router.Methods("GET").Path("/login").HandlerFunc(HandleLogin)
|
||||||
router.Methods("POST").Path("/login").HandlerFunc(HandleCallback)
|
router.Methods("POST").Path("/login").HandlerFunc(HandleCallback)
|
||||||
op.Start(ctx, handler)
|
server := &http.Server{
|
||||||
|
Addr: ":" + port,
|
||||||
|
Handler: router,
|
||||||
|
}
|
||||||
|
err = server.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,8 +16,6 @@ type Configuration interface {
|
||||||
KeysEndpoint() Endpoint
|
KeysEndpoint() Endpoint
|
||||||
|
|
||||||
AuthMethodPostSupported() bool
|
AuthMethodPostSupported() bool
|
||||||
|
|
||||||
Port() string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateIssuer(issuer string) error {
|
func ValidateIssuer(issuer string) error {
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
"github.com/caos/logging"
|
"github.com/caos/logging"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
"github.com/caos/oidc/pkg/rp"
|
"github.com/caos/oidc/pkg/rp"
|
||||||
)
|
)
|
||||||
|
@ -45,7 +46,7 @@ type DefaultOP struct {
|
||||||
signer Signer
|
signer Signer
|
||||||
verifier rp.Verifier
|
verifier rp.Verifier
|
||||||
crypto Crypto
|
crypto Crypto
|
||||||
http *http.Server
|
http http.Handler
|
||||||
decoder *schema.Decoder
|
decoder *schema.Decoder
|
||||||
encoder *schema.Encoder
|
encoder *schema.Encoder
|
||||||
interceptor HttpInterceptor
|
interceptor HttpInterceptor
|
||||||
|
@ -64,7 +65,6 @@ type Config struct {
|
||||||
// IdTokenSigningAlgValuesSupported: []string{keys.SigningAlgorithm},
|
// IdTokenSigningAlgValuesSupported: []string{keys.SigningAlgorithm},
|
||||||
// SubjectTypesSupported: []string{"public"},
|
// SubjectTypesSupported: []string{"public"},
|
||||||
// TokenEndpointAuthMethodsSupported:
|
// TokenEndpointAuthMethodsSupported:
|
||||||
Port string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type endpoints struct {
|
type endpoints struct {
|
||||||
|
@ -180,13 +180,10 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts .
|
||||||
p.signer = NewDefaultSigner(ctx, storage, keyCh)
|
p.signer = NewDefaultSigner(ctx, storage, keyCh)
|
||||||
go p.ensureKey(ctx, storage, keyCh, p.timer)
|
go p.ensureKey(ctx, storage, keyCh, p.timer)
|
||||||
|
|
||||||
p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience())
|
p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience(), rp.WithIgnoreExpiration())
|
||||||
|
|
||||||
|
p.http = CreateRouter(p, p.interceptor)
|
||||||
|
|
||||||
router := CreateRouter(p, p.interceptor)
|
|
||||||
p.http = &http.Server{
|
|
||||||
Addr: ":" + config.Port,
|
|
||||||
Handler: router,
|
|
||||||
}
|
|
||||||
p.decoder = schema.NewDecoder()
|
p.decoder = schema.NewDecoder()
|
||||||
p.decoder.IgnoreUnknownKeys(true)
|
p.decoder.IgnoreUnknownKeys(true)
|
||||||
|
|
||||||
|
@ -225,11 +222,7 @@ func (p *DefaultOP) AuthMethodPostSupported() bool {
|
||||||
return true //TODO: config
|
return true //TODO: config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DefaultOP) Port() string {
|
func (p *DefaultOP) HttpHandler() http.Handler {
|
||||||
return p.config.Port
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DefaultOP) HttpHandler() *http.Server {
|
|
||||||
return p.http
|
return p.http
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
22
pkg/op/op.go
22
pkg/op/op.go
|
@ -1,12 +1,10 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/handlers"
|
"github.com/gorilla/handlers"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
@ -26,7 +24,7 @@ type OpenIDProvider interface {
|
||||||
HandleUserinfo(w http.ResponseWriter, r *http.Request)
|
HandleUserinfo(w http.ResponseWriter, r *http.Request)
|
||||||
HandleEndSession(w http.ResponseWriter, r *http.Request)
|
HandleEndSession(w http.ResponseWriter, r *http.Request)
|
||||||
HandleKeys(w http.ResponseWriter, r *http.Request)
|
HandleKeys(w http.ResponseWriter, r *http.Request)
|
||||||
HttpHandler() *http.Server
|
HttpHandler() http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
type HttpInterceptor func(http.HandlerFunc) http.HandlerFunc
|
type HttpInterceptor func(http.HandlerFunc) http.HandlerFunc
|
||||||
|
@ -54,21 +52,3 @@ func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router {
|
||||||
router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys)
|
router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys)
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
func Start(ctx context.Context, o OpenIDProvider) {
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
err := o.HttpHandler().Shutdown(ctx)
|
|
||||||
if err != nil {
|
|
||||||
logrus.Error("graceful shutdown of oidc server failed")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
err := o.HttpHandler().ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
logrus.Panicf("oidc server serve failed: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
logrus.Infof("oidc server is listening on %s", o.Port())
|
|
||||||
}
|
|
||||||
|
|
|
@ -27,7 +27,11 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
|
||||||
RequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.Client.GetID())
|
var clientID string
|
||||||
|
if session.Client != nil {
|
||||||
|
clientID = session.Client.GetID()
|
||||||
|
}
|
||||||
|
err = ender.Storage().TerminateSession(r.Context(), session.UserID, clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, ErrServerError("error terminating session"))
|
RequestError(w, r, ErrServerError("error terminating session"))
|
||||||
return
|
return
|
||||||
|
@ -50,6 +54,9 @@ func ParseEndSessionRequest(r *http.Request, decoder *schema.Decoder) (*oidc.End
|
||||||
|
|
||||||
func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, ender SessionEnder) (*EndSessionRequest, error) {
|
func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, ender SessionEnder) (*EndSessionRequest, error) {
|
||||||
session := new(EndSessionRequest)
|
session := new(EndSessionRequest)
|
||||||
|
if req.IdTokenHint == "" {
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
claims, err := ender.IDTokenVerifier().Verify(ctx, "", req.IdTokenHint)
|
claims, err := ender.IDTokenVerifier().Verify(ctx, "", req.IdTokenHint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidRequest("id_token_hint invalid")
|
return nil, ErrInvalidRequest("id_token_hint invalid")
|
||||||
|
|
|
@ -46,13 +46,20 @@ func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ..
|
||||||
return &DefaultVerifier{config: conf, keySet: keySet}
|
return &DefaultVerifier{config: conf, keySet: keySet}
|
||||||
}
|
}
|
||||||
|
|
||||||
//WithIgnoreAudience will turn off audience claim (should only be used for id_token_hints)
|
//WithIgnoreAudience will turn off validation for audience claim (should only be used for id_token_hints)
|
||||||
func WithIgnoreAudience() func(*verifierConfig) {
|
func WithIgnoreAudience() func(*verifierConfig) {
|
||||||
return func(conf *verifierConfig) {
|
return func(conf *verifierConfig) {
|
||||||
conf.ignoreAudience = true
|
conf.ignoreAudience = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//WithIgnoreExpiration will turn off validation for expiration claim (should only be used for id_token_hints)
|
||||||
|
func WithIgnoreExpiration() func(*verifierConfig) {
|
||||||
|
return func(conf *verifierConfig) {
|
||||||
|
conf.ignoreExpiration = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//WithIgnoreIssuedAt will turn off iat claim verification
|
//WithIgnoreIssuedAt will turn off iat claim verification
|
||||||
func WithIgnoreIssuedAt() func(*verifierConfig) {
|
func WithIgnoreIssuedAt() func(*verifierConfig) {
|
||||||
return func(conf *verifierConfig) {
|
return func(conf *verifierConfig) {
|
||||||
|
@ -108,6 +115,7 @@ type verifierConfig struct {
|
||||||
clientID string
|
clientID string
|
||||||
nonce string
|
nonce string
|
||||||
ignoreAudience bool
|
ignoreAudience bool
|
||||||
|
ignoreExpiration bool
|
||||||
iat *iatConfig
|
iat *iatConfig
|
||||||
acr ACRVerifier
|
acr ACRVerifier
|
||||||
maxAge time.Duration
|
maxAge time.Duration
|
||||||
|
@ -275,10 +283,10 @@ func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString stri
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if len(jws.Signatures) == 0 {
|
if len(jws.Signatures) == 0 {
|
||||||
return "", nil //TODO: error
|
return "", ErrSignatureMissing()
|
||||||
}
|
}
|
||||||
if len(jws.Signatures) > 1 {
|
if len(jws.Signatures) > 1 {
|
||||||
return "", nil //TODO: error
|
return "", ErrSignatureMultiple()
|
||||||
}
|
}
|
||||||
sig := jws.Signatures[0]
|
sig := jws.Signatures[0]
|
||||||
supportedSigAlgs := v.config.supportedSignAlgs
|
supportedSigAlgs := v.config.supportedSignAlgs
|
||||||
|
@ -292,16 +300,18 @@ func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString stri
|
||||||
signedPayload, err := v.keySet.VerifySignature(ctx, jws)
|
signedPayload, err := v.keySet.VerifySignature(ctx, jws)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
//TODO:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(signedPayload, payload) {
|
if !bytes.Equal(signedPayload, payload) {
|
||||||
return "", ErrSignatureInvalidPayload() //TODO: err
|
return "", ErrSignatureInvalidPayload()
|
||||||
}
|
}
|
||||||
return jose.SignatureAlgorithm(sig.Header.Algorithm), nil
|
return jose.SignatureAlgorithm(sig.Header.Algorithm), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *DefaultVerifier) checkExpiration(expiration time.Time) error {
|
func (v *DefaultVerifier) checkExpiration(expiration time.Time) error {
|
||||||
|
if v.config.ignoreExpiration {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
expiration = expiration.Round(time.Second)
|
expiration = expiration.Round(time.Second)
|
||||||
if !v.now().Before(expiration) {
|
if !v.now().Before(expiration) {
|
||||||
return ErrExpInvalid(expiration)
|
return ErrExpInvalid(expiration)
|
||||||
|
@ -362,8 +372,8 @@ func (v *DefaultVerifier) decryptToken(tokenString string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
|
func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
|
||||||
if atHash == "" {
|
if accessToken == "" {
|
||||||
return nil //TODO: return error
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
|
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
|
||||||
|
@ -371,7 +381,7 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if actual != atHash {
|
if actual != atHash {
|
||||||
return nil //TODO: error
|
return ErrAtHash()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,9 +40,18 @@ var (
|
||||||
ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError {
|
ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError {
|
||||||
return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime))
|
return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime))
|
||||||
}
|
}
|
||||||
|
ErrSignatureMissing = func() *validationError {
|
||||||
|
return ValidationError("id_token does not contain a signature")
|
||||||
|
}
|
||||||
|
ErrSignatureMultiple = func() *validationError {
|
||||||
|
return ValidationError("id_token contains multiple signatures")
|
||||||
|
}
|
||||||
ErrSignatureInvalidPayload = func() *validationError {
|
ErrSignatureInvalidPayload = func() *validationError {
|
||||||
return ValidationError("Signature does not match Payload")
|
return ValidationError("Signature does not match Payload")
|
||||||
}
|
}
|
||||||
|
ErrAtHash = func() *validationError {
|
||||||
|
return ValidationError("at_hash does not correspond to access token")
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func ValidationError(message string, args ...interface{}) *validationError {
|
func ValidationError(message string, args ...interface{}) *validationError {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue