token, errors and more

This commit is contained in:
Livio Amstutz 2019-12-03 08:53:39 +01:00
parent 89bcd1a0c3
commit f04e7cf5b9
9 changed files with 64 additions and 24 deletions

View file

@ -20,6 +20,12 @@ const (
stateParam = "state"
)
var (
DefaultErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
}
)
//DefaultRP impements the `DelegationTokenExchangeRP` interface extending the `RelayingParty` interface
type DefaultRP struct {
endpoints Endpoints
@ -30,6 +36,8 @@ type DefaultRP struct {
httpClient *http.Client
cookieHandler *utils.CookieHandler
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
verifier Verifier
}
@ -51,6 +59,10 @@ func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExc
return nil, err
}
if p.errorHandler == nil {
p.errorHandler = DefaultErrorHandler
}
if p.verifier == nil {
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) //TODO: keys endpoint
}
@ -125,15 +137,16 @@ func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http
return
}
params := r.URL.Query()
if params.Get("code") != "" {
tokens, err := p.CodeExchange(r.Context(), params.Get("code"))
if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
return
}
callback(w, r, tokens, state)
if params.Get("error") != "" {
p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state)
return
}
w.Write([]byte(params.Get("error")))
tokens, err := p.CodeExchange(r.Context(), params.Get("code"))
if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
return
}
callback(w, r, tokens, state)
}
}
@ -169,18 +182,15 @@ func (p *DefaultRP) DelegationTokenExchange(ctx context.Context, subjectToken st
func (p *DefaultRP) discover() error {
wellKnown := strings.TrimSuffix(p.config.Issuer, "/") + oidc.DiscoveryEndpoint
req, err := http.NewRequest("GET", wellKnown, nil)
if err != nil {
return err
}
discoveryConfig := new(oidc.DiscoveryConfiguration)
err = utils.HttpRequest(p.httpClient, req, &discoveryConfig)
if err != nil {
return err
}
p.endpoints = GetEndpoints(discoveryConfig)
p.oauthConfig = oauth2.Config{
ClientID: p.config.ClientID,