add context

This commit is contained in:
Livio Amstutz 2019-12-18 16:05:21 +01:00
parent 0731a62833
commit 462b5c83cd
12 changed files with 104 additions and 98 deletions

View file

@ -1,6 +1,7 @@
package op
import (
"context"
"errors"
"net/http"
"time"
@ -31,13 +32,13 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
return
}
authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger)
authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
err = exchanger.Storage().DeleteAuthRequest(authReq.GetID())
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
if err != nil {
ExchangeRequestError(w, r, err)
return
@ -81,8 +82,8 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac
return tokenReq, nil
}
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
authReq, client, err := AuthorizeClient(tokenReq, exchanger)
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
if err != nil {
return nil, err
}
@ -95,44 +96,38 @@ func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exc
return authReq, nil
}
func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
if err != nil {
return nil, nil, err
}
switch client.GetAuthMethod() {
case AuthMethodNone:
authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger.Storage())
if client.GetAuthMethod() == AuthMethodNone {
authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger.Storage())
return authReq, client, err
case AuthMethodPost:
if !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("basic not supported")
}
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
case AuthMethodBasic:
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
default:
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
}
if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("basic not supported")
}
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
if err != nil {
return nil, nil, err
}
authReq, err := AuthRequestByCode(tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
if err != nil {
return nil, nil, ErrInvalidRequest("invalid code")
}
return authReq, client, nil
}
func AuthorizeClientIDSecret(clientID, clientSecret string, storage OPStorage) error {
return storage.AuthorizeClientIDSecret(clientID, clientSecret)
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error {
return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
}
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
if tokenReq.CodeVerifier == "" {
return nil, ErrInvalidRequest("code_challenge required")
}
authReq, err := AuthRequestByCode(tokenReq.Code, nil, storage)
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, nil, storage)
if err != nil {
return nil, ErrInvalidRequest("invalid code")
}
@ -142,12 +137,12 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora
return authReq, nil
}
func AuthRequestByCode(code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
id, err := crypto.Decrypt(code)
if err != nil {
return nil, err
}
return storage.AuthRequestByID(id)
return storage.AuthRequestByID(ctx, id)
}
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {