From 8ee38d2ec8c71596a0e23b2430c1c56c66a3cc3d Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Thu, 28 Nov 2019 08:01:31 +0100 Subject: [PATCH] authreq --- example/go.sum | 2 + example/internal/mock/storage.go | 31 +++++++- pkg/oidc/authorization.go | 15 ++-- pkg/oidc/client.go | 24 ++++-- pkg/op/authrequest.go | 131 +++++++++++++++++++++++++------ pkg/op/authrequest_test.go | 120 +++++++++++++++++++++++----- pkg/op/default_op.go | 43 ++++++++-- pkg/op/error.go | 101 ++++++++++++++++++++++++ pkg/op/go.mod | 1 + pkg/op/go.sum | 2 + pkg/op/mock/storage.mock.impl.go | 41 ++++++++-- pkg/op/op.go | 3 + pkg/op/storage.go | 1 + pkg/op/tokenrequest.go | 39 ++++++--- 14 files changed, 469 insertions(+), 85 deletions(-) create mode 100644 pkg/op/error.go diff --git a/example/go.sum b/example/go.sum index bf4723b..f942ed2 100644 --- a/example/go.sum +++ b/example/go.sum @@ -29,6 +29,8 @@ github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 6857d4c..0092334 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -1,6 +1,8 @@ package mock import ( + "errors" + "github.com/caos/oidc/pkg/oidc" ) @@ -11,7 +13,10 @@ func (s *Storage) CreateAuthRequest(authReq *oidc.AuthRequest) error { authReq.ID = "id" return nil } -func (s *Storage) GetClientByClientID(string) (oidc.Client, error) { +func (s *Storage) GetClientByClientID(id string) (oidc.Client, error) { + if id == "not" { + return nil, errors.New("not found") + } return &ConfClient{}, nil } func (s *Storage) AuthRequestByCode(oidc.Client, string, string) (*oidc.AuthRequest, error) { @@ -26,12 +31,26 @@ func (s *Storage) AuthorizeClientIDCodeVerifier(string, string) (oidc.Client, er func (s *Storage) DeleteAuthRequestAndCode(string, string) error { return nil } +func (s *Storage) AuthRequestByID(id string) (*oidc.AuthRequest, error) { + if id == "none" { + return nil, errors.New("not found") + } + var responseType oidc.ResponseType + if id == "code" { + responseType = oidc.ResponseTypeCode + } else if id == "id" { + responseType = oidc.ResponseTypeIDTokenOnly + } else { + responseType = oidc.ResponseTypeIDToken + } + return &oidc.AuthRequest{ + ResponseType: responseType, + RedirectURI: "/callback", + }, nil +} type ConfClient struct{} -func (c *ConfClient) Type() oidc.ClientType { - return oidc.ClientTypeConfidential -} func (c *ConfClient) RedirectURIs() []string { return []string{ "https://registered.com/callback", @@ -43,3 +62,7 @@ func (c *ConfClient) RedirectURIs() []string { func (c *ConfClient) LoginURL(id string) string { return "login?id=" + id } + +func (c *ConfClient) ApplicationType() oidc.ApplicationType { + return oidc.ApplicationTypeWeb +} diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index e3c2d24..d4927ab 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -78,18 +78,13 @@ func (a *AccessTokenRequest) GrantType() GrantType { } type AccessTokenResponse struct { - AccessToken string `json:"access_token,omitempty"` - TokenType string `json:"token_type,omitempty"` - RefreshToken string `json:"refresh_token,omitempty"` - ExpiresIn uint64 `json:"expires_in,omitempty"` - IDToken string `json:"id_token,omitempty"` + AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` + TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` + ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` + IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` } -// func (a AccessTokenRequest) UnmarshalText(text []byte) error { -// fmt.Println(string(text)) -// return nil -// } - type TokenExchangeRequest struct { subjectToken string `schema:"subject_token"` subjectTokenType string `schema:"subject_token_type"` diff --git a/pkg/oidc/client.go b/pkg/oidc/client.go index c156175..fe243b2 100644 --- a/pkg/oidc/client.go +++ b/pkg/oidc/client.go @@ -2,17 +2,29 @@ package oidc type Client interface { RedirectURIs() []string - Type() ClientType + ApplicationType() ApplicationType LoginURL(string) string } -type ClientType int +// type ClientType int -func (c ClientType) IsConvidential() bool { - return c == ClientTypeConfidential +// func (c ClientType) IsConvidential() bool { +// return c == ClientTypeConfidential +// } + +func IsConfidentialType(c Client) bool { + return c.ApplicationType() == ApplicationTypeWeb } +type ApplicationType int + +// const (a ApplicationType) + const ( - ClientTypeConfidential ClientType = iota - ClientTypePublic + // ClientTypeConfidential ClientType = iota + // ClientTypePublic + + ApplicationTypeWeb ApplicationType = iota + ApplicationTypeUserAgent + ApplicationTypeNative ) diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index 1c78740..df86df4 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -1,54 +1,76 @@ package op import ( - "errors" + "fmt" "net/http" + "net/url" + "strings" + "github.com/gorilla/mux" "github.com/gorilla/schema" "github.com/caos/oidc/pkg/oidc" str_utils "github.com/caos/utils/strings" ) -func Authorize(w http.ResponseWriter, r *http.Request, storage Storage) (*oidc.AuthRequest, error) { +type Authorizer interface { + Storage() Storage + Decoder() *schema.Decoder + Encoder() *schema.Encoder + Signer() Signer +} + +type ValidationAuthorizer interface { + Authorizer + ValidateAuthRequest(*oidc.AuthRequest, Storage) error +} + +func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { err := r.ParseForm() if err != nil { - return nil, errors.New("Unimplemented") //TODO: impl + AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form: %v", err)) + return } authReq := new(oidc.AuthRequest) - //TODO: - d := schema.NewDecoder() - d.IgnoreUnknownKeys(true) + err = authorizer.Decoder().Decode(authReq, r.Form) + if err != nil { + AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err))) + return + } - err = d.Decode(authReq, r.Form) - if err != nil { - return nil, err + validation := ValidateAuthRequest + if validater, ok := authorizer.(ValidationAuthorizer); ok { + validation = validater.ValidateAuthRequest } - if err = ValidateAuthRequest(authReq, storage); err != nil { - return nil, err + if err := validation(authReq, authorizer.Storage()); err != nil { + AuthRequestError(w, r, authReq, err) + return } - err = storage.CreateAuthRequest(authReq) + + err = authorizer.Storage().CreateAuthRequest(authReq) if err != nil { - //TODO: return err + AuthRequestError(w, r, authReq, err) + return } - client, err := storage.GetClientByClientID(authReq.ClientID) + + client, err := authorizer.Storage().GetClientByClientID(authReq.ClientID) if err != nil { - return nil, err + AuthRequestError(w, r, authReq, err) + return } RedirectToLogin(authReq, client, w, r) - return nil, nil } func ValidateAuthRequest(authReq *oidc.AuthRequest, storage Storage) error { if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { return err } - if err := ValidateAuthReqRedirectURI(authReq.RedirectURI, authReq.ClientID, storage); err != nil { + if err := ValidateAuthReqRedirectURI(authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil { return err } return nil - return errors.New("Unimplemented") //TODO: impl https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.3.1.2.2 + // return errors.New("Unimplemented") //TODO: impl https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.3.1.2.2 // if NeedsExistingSession(authRequest) { // session, err := storage.CheckSession(authRequest) @@ -60,24 +82,43 @@ func ValidateAuthRequest(authReq *oidc.AuthRequest, storage Storage) error { func ValidateAuthReqScopes(scopes []string) error { if len(scopes) == 0 { - return errors.New("scope missing") + return ErrInvalidRequest("scope missing") } if !str_utils.Contains(scopes, oidc.ScopeOpenID) { - return errors.New("scope openid missing") + return ErrInvalidRequest("scope openid missing") } return nil } -func ValidateAuthReqRedirectURI(uri, client_id string, storage Storage) error { +func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage Storage) error { if uri == "" { - return errors.New("redirect_uri must not be empty") //TODO: + return ErrInvalidRequest("redirect_uri must not be empty") } client, err := storage.GetClientByClientID(client_id) if err != nil { - return err + return ErrServerError(err.Error()) } if !str_utils.Contains(client.RedirectURIs(), uri) { - return errors.New("redirect_uri not allowed") + return ErrInvalidRequest("redirect_uri not allowed") + } + if strings.HasPrefix(uri, "https://") { + return nil + } + if responseType == oidc.ResponseTypeCode { + if strings.HasPrefix(uri, "http://") && oidc.IsConfidentialType(client) { + return nil + } + if client.ApplicationType() == oidc.ApplicationTypeNative { + return nil + } + return ErrInvalidRequest("redirect_uri not allowed 2") + } else { + if client.ApplicationType() != oidc.ApplicationTypeNative { + return ErrInvalidRequest("redirect_uri not allowed 3") + } + if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) { + return ErrInvalidRequest("redirect_uri not allowed 4") + } } return nil } @@ -86,3 +127,45 @@ func RedirectToLogin(authReq *oidc.AuthRequest, client oidc.Client, w http.Respo login := client.LoginURL(authReq.ID) http.Redirect(w, r, login, http.StatusFound) } + +func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { + params := mux.Vars(r) + id := params["id"] + + authReq, err := authorizer.Storage().AuthRequestByID(id) + if err != nil { + AuthRequestError(w, r, nil, err) + return + } + AuthResponse(authReq, authorizer, w, r) +} + +func AuthResponse(authReq *oidc.AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { + var callback string + if authReq.ResponseType == oidc.ResponseTypeCode { + callback = fmt.Sprintf("%s?code=%s", authReq.RedirectURI, "test") + } else { + var accessToken string + var err error + if authReq.ResponseType != oidc.ResponseTypeIDTokenOnly { + accessToken, err = CreateAccessToken() + if err != nil { + + } + } + idToken, err := CreateIDToken(authReq, accessToken, authorizer.Signer()) + if err != nil { + + } + resp := &oidc.AccessTokenResponse{ + AccessToken: accessToken, + IDToken: idToken, + TokenType: "Bearer", + } + values := make(map[string][]string) + authorizer.Encoder().Encode(resp, values) + v := url.Values(values) + callback = fmt.Sprintf("%s#%s", authReq.RedirectURI, v.Encode()) + } + http.Redirect(w, r, callback, http.StatusFound) +} diff --git a/pkg/op/authrequest_test.go b/pkg/op/authrequest_test.go index c01f8eb..defe085 100644 --- a/pkg/op/authrequest_test.go +++ b/pkg/op/authrequest_test.go @@ -1,12 +1,15 @@ package op import ( + "net/http" + "net/http/httptest" "testing" - "github.com/caos/oidc/pkg/op" - "github.com/caos/oidc/pkg/op/mock" + "github.com/gorilla/schema" "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/op" + "github.com/caos/oidc/pkg/op/mock" ) func TestValidateAuthRequest(t *testing.T) { @@ -58,11 +61,12 @@ func TestValidateAuthRequest(t *testing.T) { } } -func TestValidateRedirectURI(t *testing.T) { +func TestValidateAuthReqRedirectURI(t *testing.T) { type args struct { - uri string - clientID string - storage op.Storage + uri string + clientID string + responseType oidc.ResponseType + storage op.Storage } tests := []struct { name string @@ -71,40 +75,120 @@ func TestValidateRedirectURI(t *testing.T) { }{ { "empty fails", - args{"", "", nil}, + args{"", "", oidc.ResponseTypeCode, nil}, true, }, { "unregistered fails", - args{"https://unregistered.com/callback", "client_id", mock.NewMockStorageExpectValidClientID(t)}, + args{"https://unregistered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)}, true, }, { - "http not allowed fails", - args{"http://registered.com/callback", "client_id", mock.NewMockStorageExpectValidClientID(t)}, + "storage error fails", + args{"https://registered.com/callback", "non_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectInvalidClientID(t)}, true, }, { - "registered https ok", - args{"https://registered.com/callback", "client_id", mock.NewMockStorageExpectValidClientID(t)}, + "code flow registered http not confidential fails", + args{"http://registered.com/callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, + { + "code flow registered http confidential ok", + args{"http://registered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)}, false, }, { - "registered http allowed ok", - args{"http://localhost:9999/callback", "client_id", mock.NewMockStorageExpectValidClientID(t)}, + "code flow registered custom not native fails", + args{"custom://callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, + { + "code flow registered custom native ok", + args{"http://registered.com/callback", "native_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)}, false, }, { - "registered scheme ok", - args{"custom://callback", "client_id", mock.NewMockStorageExpectValidClientID(t)}, + "implicit flow registered ok", + args{"https://registered.com/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)}, false, }, + { + "implicit flow registered http localhost native ok", + args{"http://localhost:9999/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)}, + false, + }, + { + "implicit flow registered http localhost user agent fails", + args{"http://localhost:9999/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, + { + "implicit flow http non localhost fails", + args{"http://registered.com/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, + { + "implicit flow custom fails", + args{"custom://callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := ValidateAuthReqRedirectURI(tt.args.uri, tt.args.clientID, tt.args.storage); (err != nil) != tt.wantErr { - t.Errorf("ValidateRedirectURI() error = %v, wantErr %v", err, tt.wantErr) + if err := ValidateAuthReqRedirectURI(tt.args.uri, tt.args.clientID, tt.args.responseType, tt.args.storage); (err != nil) != tt.wantErr { + t.Errorf("ValidateRedirectURI() error = %v, wantErr %v", err.Error(), tt.wantErr) } }) } } + +func TestValidateAuthReqScopes(t *testing.T) { + type args struct { + scopes []string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "scopes missing fails", args{}, true, + }, + { + "scope openid missing fails", args{[]string{"email"}}, true, + }, + { + "scope ok", args{[]string{"openid"}}, false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr { + t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAuthorize(t *testing.T) { + type args struct { + w http.ResponseWriter + r *http.Request + storage Storage + decoder *schema.Decoder + } + tests := []struct { + name string + args args + }{ + {"parsing fails", args{httptest.NewRecorder(), &http.Request{Method: "POST", Body: nil}, nil, nil}}, + {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, nil, schema.NewDecoder()}}, + {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, nil, schema.NewDecoder()}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Authorize(tt.args.w, tt.args.r, tt.args.storage, tt.args.decoder) + }) + } +} diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index c547f3e..e80a129 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -7,6 +7,8 @@ import ( "net/url" "strings" + "github.com/gorilla/schema" + "github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/oidc" @@ -18,6 +20,8 @@ type DefaultOP struct { discoveryConfig *oidc.DiscoveryConfiguration storage Storage http *http.Server + decoder *schema.Decoder + encoder *schema.Encoder } type Config struct { @@ -133,6 +137,10 @@ func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (Ope Addr: ":" + config.Port, Handler: router, } + p.decoder = schema.NewDecoder() + p.decoder.IgnoreUnknownKeys(true) + + p.encoder = schema.NewEncoder() return p, nil } @@ -157,7 +165,6 @@ func (e Endpoint) Validate() error { func (p *DefaultOP) AuthorizationEndpoint() Endpoint { return p.endpoints.Authorization - } func (p *DefaultOP) TokenEndpoint() Endpoint { @@ -180,11 +187,28 @@ func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) { utils.MarshalJSON(w, p.discoveryConfig) } +func (p *DefaultOP) Decoder() *schema.Decoder { + return p.decoder +} + +func (p *DefaultOP) Encoder() *schema.Encoder { + return p.encoder +} + +func (p *DefaultOP) Storage() Storage { + return p.storage +} + +func (p *DefaultOP) Signer() Signer { + // return p.signer + return nil +} + func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) { - _, err := Authorize(w, r, p.storage) - if err != nil { - http.Error(w, err.Error(), 400) - } + Authorize(w, r, p) + // if err != nil { + // http.Error(w, err.Error(), 400) + // } // authRequest, err := ParseAuthRequest(w, r) // if err != nil { // //TODO: return err @@ -203,13 +227,18 @@ func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) { // RedirectToLogin(authRequest, client, w, r) } +func (p *DefaultOP) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) { + AuthorizeCallback(w, r, p) +} + func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) { reqType := r.FormValue("grant_type") if reqType == "" { - //return errors.New("grant_type missing") //TODO: impl + ExchangeRequestError(w, r, nil, ErrInvalidRequest("grant_type missing")) + return } if reqType == string(oidc.GrantTypeCode) { - token, err := CodeExchange(w, r, p.storage) + token, err := CodeExchange(w, r, p.storage, p.decoder) if err != nil { } diff --git a/pkg/op/error.go b/pkg/op/error.go new file mode 100644 index 0000000..1aacc27 --- /dev/null +++ b/pkg/op/error.go @@ -0,0 +1,101 @@ +package op + +import ( + "net/http" + + "github.com/caos/oidc/pkg/oidc" +) + +const ( + InvalidRequest errorType = "invalid_request" + ServerError errorType = "server_error" +) + +type errorType string + +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { + if authReq == nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if authReq.RedirectURI == "" { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + url := authReq.RedirectURI + if authReq.ResponseType == oidc.ResponseTypeCode { + url += "?" + } else { + url += "#" + } + var errorType errorType + var description string + if e, ok := err.(*OAuthError); ok { + errorType = e.ErrorType + description = e.Description + } else { + errorType = ServerError + description = err.Error() + } + url += "error=" + string(errorType) + if description != "" { + url += "&error_description=" + description + } + if authReq.State != "" { + url += "&state=" + authReq.State + } + http.Redirect(w, r, url, http.StatusFound) +} + +func ExchangeRequestError(w http.ResponseWriter, r *http.Request, exchangeReq *oidc.AuthRequest, err error) { + +} + +type OAuthError struct { + ErrorType errorType `json:"error"` + Description string `json:"description"` +} + +var ( + ErrInvalidRequest = func(description string, args ...interface{}) *OAuthError { + return &OAuthError{ + ErrorType: InvalidRequest, + Description: description, + } + } + ErrServerError = func(description string, args ...interface{}) *OAuthError { + return &OAuthError{ + ErrorType: ServerError, + Description: description, + } + } +) + +func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest) { + if authReq == nil { + http.Error(w, e.Error(), http.StatusBadRequest) + return + } + if authReq.RedirectURI == "" { + http.Error(w, e.Error(), http.StatusBadRequest) + return + } + url := authReq.RedirectURI + if authReq.ResponseType == oidc.ResponseTypeCode { + url += "?" + } else { + url += "#" + } + url += "error=" + string(e.ErrorType) + if e.Description != "" { + url += "&error_description=" + e.Description + } + if authReq.State != "" { + url += "&state=" + authReq.State + } + http.Redirect(w, r, url, http.StatusFound) +} + +func (e *OAuthError) Error() string { + return "" +} diff --git a/pkg/op/go.mod b/pkg/op/go.mod index c99e2a5..3c73f65 100644 --- a/pkg/op/go.mod +++ b/pkg/op/go.mod @@ -18,6 +18,7 @@ require ( github.com/caos/utils v0.0.0-20191104132131-b318678afbef github.com/caos/utils/logging v0.0.0-20191104132131-b318678afbef github.com/golang/mock v1.3.1 + github.com/google/go-querystring v1.0.0 github.com/gorilla/mux v1.7.3 github.com/gorilla/schema v1.1.0 github.com/stretchr/testify v1.4.0 diff --git a/pkg/op/go.sum b/pkg/op/go.sum index 95796ea..844f1fa 100644 --- a/pkg/op/go.sum +++ b/pkg/op/go.sum @@ -29,6 +29,8 @@ github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/schema v1.1.0 h1:CamqUDOFUBqzrvxuz2vEwo8+SUdwsluFh7IlzJh30LY= diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index f4c002c..ca21159 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -1,11 +1,13 @@ package mock import ( + "errors" "testing" + "github.com/caos/oidc/pkg/oidc" + "github.com/golang/mock/gomock" - "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/op" ) @@ -19,6 +21,12 @@ func NewMockStorageExpectValidClientID(t *testing.T) op.Storage { return m } +func NewMockStorageExpectInvalidClientID(t *testing.T) op.Storage { + m := NewStorage(t) + ExpectInvalidClientID(m) + return m +} + func NewMockStorageAny(t *testing.T) op.Storage { m := NewStorage(t) mockS := m.(*MockStorage) @@ -27,19 +35,36 @@ func NewMockStorageAny(t *testing.T) op.Storage { return m } +func ExpectInvalidClientID(s op.Storage) { + mockS := s.(*MockStorage) + mockS.EXPECT().GetClientByClientID(gomock.Any()).Return(nil, errors.New("client not found")) +} + func ExpectValidClientID(s op.Storage) { mockS := s.(*MockStorage) - mockS.EXPECT().GetClientByClientID(gomock.Any()).Return(&ConfClient{}, nil) + mockS.EXPECT().GetClientByClientID(gomock.Any()).DoAndReturn( + func(id string) (oidc.Client, error) { + var appType oidc.ApplicationType + switch id { + case "web_client": + appType = oidc.ApplicationTypeWeb + case "native_client": + appType = oidc.ApplicationTypeNative + case "useragent_client": + appType = oidc.ApplicationTypeUserAgent + } + return &ConfClient{appType: appType}, nil + }) } -type ConfClient struct{} - -func (c *ConfClient) Type() oidc.ClientType { - return oidc.ClientTypeConfidential +type ConfClient struct { + appType oidc.ApplicationType } + func (c *ConfClient) RedirectURIs() []string { return []string{ "https://registered.com/callback", + "http://registered.com/callback", "http://localhost:9999/callback", "custom://callback", } @@ -48,3 +73,7 @@ func (c *ConfClient) RedirectURIs() []string { func (c *ConfClient) LoginURL(id string) string { return "login?id=" + id } + +func (c *ConfClient) ApplicationType() oidc.ApplicationType { + return c.appType +} diff --git a/pkg/op/op.go b/pkg/op/op.go index 5a439d5..e3f5f70 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -15,8 +15,10 @@ type OpenIDProvider interface { // Storage() Storage HandleDiscovery(w http.ResponseWriter, r *http.Request) HandleAuthorize(w http.ResponseWriter, r *http.Request) + HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) HandleExchange(w http.ResponseWriter, r *http.Request) HandleUserinfo(w http.ResponseWriter, r *http.Request) + // Storage() Storage HttpHandler() *http.Server } @@ -24,6 +26,7 @@ func CreateRouter(o OpenIDProvider) *mux.Router { router := mux.NewRouter() router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery) router.HandleFunc(o.AuthorizationEndpoint().Relative(), o.HandleAuthorize) + router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", o.HandleAuthorizeCallback) router.HandleFunc(o.TokenEndpoint().Relative(), o.HandleExchange) router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo) return router diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 388bb22..5fcf652 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -5,6 +5,7 @@ import "github.com/caos/oidc/pkg/oidc" type Storage interface { CreateAuthRequest(*oidc.AuthRequest) error GetClientByClientID(string) (oidc.Client, error) + AuthRequestByID(string) (*oidc.AuthRequest, error) AuthRequestByCode(oidc.Client, string, string) (*oidc.AuthRequest, error) AuthorizeClientIDSecret(string, string) (oidc.Client, error) AuthorizeClientIDCodeVerifier(string, string) (oidc.Client, error) diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index b80b4f8..6f9664e 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -3,6 +3,7 @@ package op import ( "errors" "net/http" + "time" "github.com/gorilla/schema" @@ -20,18 +21,14 @@ import ( // return ParseTokenExchangeRequest(w, r) // } -func CodeExchange(w http.ResponseWriter, r *http.Request, storage Storage) (*oidc.AccessTokenResponse, error) { +func CodeExchange(w http.ResponseWriter, r *http.Request, storage Storage, decoder *schema.Decoder) (*oidc.AccessTokenResponse, error) { err := r.ParseForm() if err != nil { return nil, errors.New("Unimplemented") //TODO: impl } tokenReq := new(oidc.AccessTokenRequest) - //TODO: - d := schema.NewDecoder() - d.IgnoreUnknownKeys(true) - - err = d.Decode(tokenReq, r.Form) + err = decoder.Decode(tokenReq, r.Form) if err != nil { return nil, err } @@ -55,7 +52,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, storage Storage) (*oid if err != nil { } - idToken, err := CreateIDToken() + idToken, err := CreateIDToken(nil, "", nil) if err != nil { } @@ -67,10 +64,32 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, storage Storage) (*oid } func CreateAccessToken() (string, error) { - return "", nil + return "accessToken", nil } -func CreateIDToken() (string, error) { - return "", nil + +type Signer interface { + Sign(claims *oidc.IDTokenClaims) (string, error) +} + +func CreateIDToken(authReq *oidc.AuthRequest, atHash string, signer Signer) (string, error) { + var issuer, sub, acr string + var aud, amr []string + var exp, iat, authTime time.Time + + claims := &oidc.IDTokenClaims{ + Issuer: issuer, + Subject: sub, + Audiences: aud, + Expiration: exp, + IssuedAt: iat, + AuthTime: authTime, + Nonce: authReq.Nonce, + AuthenticationContextClassReference: acr, + AuthenticationMethodsReferences: amr, + AuthorizedParty: authReq.ClientID, + AccessTokenHash: atHash, + } + return signer.Sign(claims) } func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, storage Storage) (oidc.Client, error) {