add context
This commit is contained in:
parent
0731a62833
commit
462b5c83cd
12 changed files with 104 additions and 98 deletions
|
@ -1,6 +1,7 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
@ -31,13 +32,13 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|||
return
|
||||
}
|
||||
|
||||
authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger)
|
||||
authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = exchanger.Storage().DeleteAuthRequest(authReq.GetID())
|
||||
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
|
@ -81,8 +82,8 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac
|
|||
return tokenReq, nil
|
||||
}
|
||||
|
||||
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
||||
authReq, client, err := AuthorizeClient(tokenReq, exchanger)
|
||||
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
||||
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -95,44 +96,38 @@ func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exc
|
|||
return authReq, nil
|
||||
}
|
||||
|
||||
func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||
client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
|
||||
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||
client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
switch client.GetAuthMethod() {
|
||||
case AuthMethodNone:
|
||||
authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger.Storage())
|
||||
if client.GetAuthMethod() == AuthMethodNone {
|
||||
authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger.Storage())
|
||||
return authReq, client, err
|
||||
case AuthMethodPost:
|
||||
if !exchanger.AuthMethodPostSupported() {
|
||||
return nil, nil, errors.New("basic not supported")
|
||||
}
|
||||
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||
case AuthMethodBasic:
|
||||
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||
default:
|
||||
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||
}
|
||||
if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() {
|
||||
return nil, nil, errors.New("basic not supported")
|
||||
}
|
||||
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
authReq, err := AuthRequestByCode(tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvalidRequest("invalid code")
|
||||
}
|
||||
return authReq, client, nil
|
||||
}
|
||||
|
||||
func AuthorizeClientIDSecret(clientID, clientSecret string, storage OPStorage) error {
|
||||
return storage.AuthorizeClientIDSecret(clientID, clientSecret)
|
||||
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error {
|
||||
return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
|
||||
}
|
||||
|
||||
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
|
||||
func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
|
||||
if tokenReq.CodeVerifier == "" {
|
||||
return nil, ErrInvalidRequest("code_challenge required")
|
||||
}
|
||||
authReq, err := AuthRequestByCode(tokenReq.Code, nil, storage)
|
||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, nil, storage)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidRequest("invalid code")
|
||||
}
|
||||
|
@ -142,12 +137,12 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora
|
|||
return authReq, nil
|
||||
}
|
||||
|
||||
func AuthRequestByCode(code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
|
||||
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
|
||||
id, err := crypto.Decrypt(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return storage.AuthRequestByID(id)
|
||||
return storage.AuthRequestByID(ctx, id)
|
||||
}
|
||||
|
||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue