diff --git a/example/client/device/device.go b/example/client/device/device.go new file mode 100644 index 0000000..284ba37 --- /dev/null +++ b/example/client/device/device.go @@ -0,0 +1,61 @@ +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/sirupsen/logrus" + + "github.com/zitadel/oidc/v2/pkg/client/rp" + httphelper "github.com/zitadel/oidc/v2/pkg/http" +) + +var ( + key = []byte("test1234test1234") +) + +func main() { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT) + defer stop() + + clientID := os.Getenv("CLIENT_ID") + clientSecret := os.Getenv("CLIENT_SECRET") + keyPath := os.Getenv("KEY_PATH") + issuer := os.Getenv("ISSUER") + scopes := strings.Split(os.Getenv("SCOPES"), " ") + + cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) + + var options []rp.Option + if clientSecret == "" { + options = append(options, rp.WithPKCE(cookieHandler)) + } + if keyPath != "" { + options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) + } + + provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...) + if err != nil { + logrus.Fatalf("error creating provider %s", err.Error()) + } + + logrus.Info("starting device authorization flow") + resp, err := rp.DeviceAuthorization(scopes, provider) + if err != nil { + logrus.Fatal(err) + } + logrus.Info("resp", resp) + fmt.Printf("\nPlease browse to %s and enter code %s\n", resp.VerificationURI, resp.UserCode) + + logrus.Info("start polling") + token, err := rp.DeviceAccessToken(ctx, resp.DeviceCode, time.Duration(resp.Interval)*time.Second, provider) + if err != nil { + logrus.Fatal(err) + } + logrus.Infof("successfully obtained token: %v", token) +} diff --git a/example/server/exampleop/device.go b/example/server/exampleop/device.go new file mode 100644 index 0000000..ae2e8f2 --- /dev/null +++ b/example/server/exampleop/device.go @@ -0,0 +1,191 @@ +package exampleop + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/gorilla/mux" + "github.com/gorilla/securecookie" + "github.com/sirupsen/logrus" + "github.com/zitadel/oidc/v2/pkg/op" +) + +type deviceAuthenticate interface { + CheckUsernamePasswordSimple(username, password string) error + op.DeviceAuthorizationStorage +} + +type deviceLogin struct { + storage deviceAuthenticate + cookie *securecookie.SecureCookie +} + +func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) { + l := &deviceLogin{ + storage: storage, + cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil), + } + + router.HandleFunc("", l.userCodeHandler) + router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler) + router.HandleFunc("/confirm", l.confirmHandler) +} + +func renderUserCode(w io.Writer, err error) { + data := struct { + Error string + }{ + Error: errMsg(err), + } + + if err := templates.ExecuteTemplate(w, "usercode", data); err != nil { + logrus.Error(err) + } +} + +func renderDeviceLogin(w http.ResponseWriter, userCode string, err error) { + data := &struct { + UserCode string + Error string + }{ + UserCode: userCode, + Error: errMsg(err), + } + if err = templates.ExecuteTemplate(w, "device_login", data); err != nil { + logrus.Error(err) + } +} + +func renderConfirmPage(w http.ResponseWriter, username, clientID string, scopes []string) { + data := &struct { + Username string + ClientID string + Scopes []string + }{ + Username: username, + ClientID: clientID, + Scopes: scopes, + } + if err := templates.ExecuteTemplate(w, "confirm_device", data); err != nil { + logrus.Error(err) + } +} + +func (d *deviceLogin) userCodeHandler(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + if err != nil { + w.WriteHeader(http.StatusBadRequest) + renderUserCode(w, err) + return + } + userCode := r.Form.Get("user_code") + if userCode == "" { + if prompt, _ := url.QueryUnescape(r.Form.Get("prompt")); prompt != "" { + err = errors.New(prompt) + } + renderUserCode(w, err) + return + } + + renderDeviceLogin(w, userCode, nil) +} + +func redirectBack(w http.ResponseWriter, r *http.Request, prompt string) { + values := make(url.Values) + values.Set("prompt", url.QueryEscape(prompt)) + + url := url.URL{ + Path: "/device", + RawQuery: values.Encode(), + } + http.Redirect(w, r, url.String(), http.StatusSeeOther) +} + +const userCodeCookieName = "user_code" + +type userCodeCookie struct { + UserCode string + UserName string +} + +func (d *deviceLogin) loginHandler(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + redirectBack(w, r, err.Error()) + return + } + + userCode := r.PostForm.Get("user_code") + if userCode == "" { + redirectBack(w, r, "missing user_code in request") + return + } + username := r.PostForm.Get("username") + if username == "" { + redirectBack(w, r, "missing username in request") + return + } + password := r.PostForm.Get("password") + if password == "" { + redirectBack(w, r, "missing password in request") + return + } + + if err := d.storage.CheckUsernamePasswordSimple(username, password); err != nil { + redirectBack(w, r, err.Error()) + return + } + state, err := d.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode) + if err != nil { + redirectBack(w, r, err.Error()) + return + } + + encoded, err := d.cookie.Encode(userCodeCookieName, userCodeCookie{userCode, username}) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + cookie := &http.Cookie{ + Name: userCodeCookieName, + Value: encoded, + Path: "/", + } + http.SetCookie(w, cookie) + renderConfirmPage(w, username, state.ClientID, state.Scopes) +} + +func (d *deviceLogin) confirmHandler(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(userCodeCookieName) + if err != nil { + redirectBack(w, r, err.Error()) + return + } + data := new(userCodeCookie) + if err = d.cookie.Decode(userCodeCookieName, cookie.Value, &data); err != nil { + redirectBack(w, r, err.Error()) + return + } + if err = r.ParseForm(); err != nil { + redirectBack(w, r, err.Error()) + return + } + + action := r.Form.Get("action") + switch action { + case "allowed": + err = d.storage.CompleteDeviceAuthorization(r.Context(), data.UserCode, data.UserName) + case "denied": + err = d.storage.DenyDeviceAuthorization(r.Context(), data.UserCode) + default: + err = errors.New("action must be one of \"allow\" or \"deny\"") + } + if err != nil { + redirectBack(w, r, err.Error()) + return + } + + fmt.Fprintf(w, "Device authorization %s. You can now return to the device", action) +} diff --git a/example/server/exampleop/login.go b/example/server/exampleop/login.go index 5da86d1..c014c9a 100644 --- a/example/server/exampleop/login.go +++ b/example/server/exampleop/login.go @@ -3,45 +3,11 @@ package exampleop import ( "context" "fmt" - "html/template" "net/http" "github.com/gorilla/mux" ) -const ( - queryAuthRequestID = "authRequestID" -) - -var loginTmpl, _ = template.New("login").Parse(` - - - - - Login - - -
- - - -
- - -
- -
- - -
- -

{{.Error}}

- - -
- - `) - type login struct { authenticate authenticate router *mux.Router @@ -74,23 +40,19 @@ func (l *login) loginHandler(w http.ResponseWriter, r *http.Request) { return } // the oidc package will pass the id of the auth request as query parameter - // we will use this id through the login process and therefore pass it to the login page + // we will use this id through the login process and therefore pass it to the login page renderLogin(w, r.FormValue(queryAuthRequestID), nil) } func renderLogin(w http.ResponseWriter, id string, err error) { - var errMsg string - if err != nil { - errMsg = err.Error() - } data := &struct { ID string Error string }{ ID: id, - Error: errMsg, + Error: errMsg(err), } - err = loginTmpl.Execute(w, data) + err = templates.ExecuteTemplate(w, "login", data) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index d3a450c..b46be7f 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "log" "net/http" + "time" "github.com/gorilla/mux" "golang.org/x/text/language" @@ -27,7 +28,8 @@ func init() { type Storage interface { op.Storage - CheckUsernamePassword(username, password, id string) error + authenticate + deviceAuthenticate } // SetupServer creates an OIDC server with Issuer=http://localhost: @@ -62,6 +64,9 @@ func SetupServer(ctx context.Context, issuer string, storage Storage) *mux.Route // so we will direct all calls to /login to the login UI router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) + router.PathPrefix("/device").Subrouter() + registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter()) + // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // is served on the correct path // @@ -99,6 +104,13 @@ func newOP(ctx context.Context, storage op.Storage, issuer string, key [32]byte) // this example has only static texts (in English), so we'll set the here accordingly SupportedUILocales: []language.Tag{language.English}, + + DeviceAuthorization: op.DeviceAuthorizationConfig{ + Lifetime: 5 * time.Minute, + PollInterval: 5 * time.Second, + UserFormURL: issuer + "device", + UserCode: op.UserCodeBase20, + }, } handler, err := op.NewOpenIDProvider(ctx, issuer, config, storage, //we must explicitly allow the use of the http issuer diff --git a/example/server/exampleop/templates.go b/example/server/exampleop/templates.go new file mode 100644 index 0000000..5b5c966 --- /dev/null +++ b/example/server/exampleop/templates.go @@ -0,0 +1,26 @@ +package exampleop + +import ( + "embed" + "html/template" + + "github.com/sirupsen/logrus" +) + +var ( + //go:embed templates + templateFS embed.FS + templates = template.Must(template.ParseFS(templateFS, "templates/*.html")) +) + +const ( + queryAuthRequestID = "authRequestID" +) + +func errMsg(err error) string { + if err == nil { + return "" + } + logrus.Error(err) + return err.Error() +} diff --git a/example/server/exampleop/templates/confirm_device.html b/example/server/exampleop/templates/confirm_device.html new file mode 100644 index 0000000..a6bcdad --- /dev/null +++ b/example/server/exampleop/templates/confirm_device.html @@ -0,0 +1,25 @@ +{{ define "confirm_device" -}} + + + + + Confirm device authorization + + + +

Welcome back {{.Username}}!

+

+ You are about to grant device {{.ClientID}} access to the following scopes: {{.Scopes}}. +

+ + + + +{{- end }} diff --git a/example/server/exampleop/templates/device_login.html b/example/server/exampleop/templates/device_login.html new file mode 100644 index 0000000..cc5b00b --- /dev/null +++ b/example/server/exampleop/templates/device_login.html @@ -0,0 +1,29 @@ +{{ define "device_login" -}} + + + + + Login + + +
+ + + +
+ + +
+ +
+ + +
+ +

{{.Error}}

+ + +
+ + +{{- end }} diff --git a/example/server/exampleop/templates/login.html b/example/server/exampleop/templates/login.html new file mode 100644 index 0000000..b048211 --- /dev/null +++ b/example/server/exampleop/templates/login.html @@ -0,0 +1,29 @@ +{{ define "login" -}} + + + + + Login + + +
+ + + +
+ + +
+ +
+ + +
+ +

{{.Error}}

+ + +
+ +` +{{- end }} \ No newline at end of file diff --git a/example/server/exampleop/templates/usercode.html b/example/server/exampleop/templates/usercode.html new file mode 100644 index 0000000..fb8fa7f --- /dev/null +++ b/example/server/exampleop/templates/usercode.html @@ -0,0 +1,21 @@ +{{ define "usercode" -}} + + + + + Device authorization + + +
+

Device authorization

+
+ + +
+

{{.Error}}

+ + +
+ + +{{- end }} diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go index 662132c..b49ce1b 100644 --- a/example/server/storage/storage.go +++ b/example/server/storage/storage.go @@ -44,6 +44,8 @@ type Storage struct { services map[string]Service refreshTokens map[string]*RefreshToken signingKey signingKey + deviceCodes map[string]deviceAuthorizationEntry + userCodes map[string]string } type signingKey struct { @@ -105,6 +107,8 @@ func NewStorage(userStore UserStore) *Storage { algorithm: jose.RS256, key: key, }, + deviceCodes: make(map[string]deviceAuthorizationEntry), + userCodes: make(map[string]string), } } @@ -135,6 +139,17 @@ func (s *Storage) CheckUsernamePassword(username, password, id string) error { return fmt.Errorf("username or password wrong") } +func (s *Storage) CheckUsernamePasswordSimple(username, password string) error { + s.lock.Lock() + defer s.lock.Unlock() + + user := s.userStore.GetUserByUsername(username) + if user != nil && user.Password == password { + return nil + } + return fmt.Errorf("username or password wrong") +} + // CreateAuthRequest implements the op.Storage interface // it will be called after parsing and validation of the authentication request func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) { @@ -735,3 +750,85 @@ func appendClaim(claims map[string]interface{}, claim string, value interface{}) claims[claim] = value return claims } + +type deviceAuthorizationEntry struct { + deviceCode string + userCode string + state *op.DeviceAuthorizationState +} + +func (s *Storage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error { + s.lock.Lock() + defer s.lock.Unlock() + + if _, ok := s.clients[clientID]; !ok { + return errors.New("client not found") + } + + if _, ok := s.userCodes[userCode]; ok { + return op.ErrDuplicateUserCode + } + + s.deviceCodes[deviceCode] = deviceAuthorizationEntry{ + deviceCode: deviceCode, + userCode: userCode, + state: &op.DeviceAuthorizationState{ + ClientID: clientID, + Scopes: scopes, + Expires: expires, + }, + } + + s.userCodes[userCode] = deviceCode + return nil +} + +func (s *Storage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*op.DeviceAuthorizationState, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.lock.Lock() + defer s.lock.Unlock() + + entry, ok := s.deviceCodes[deviceCode] + if !ok || entry.state.ClientID != clientID { + return nil, errors.New("device code not found for client") // is there a standard not found error in the framework? + } + + return entry.state, nil +} + +func (s *Storage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) { + s.lock.Lock() + defer s.lock.Unlock() + + entry, ok := s.deviceCodes[s.userCodes[userCode]] + if !ok { + return nil, errors.New("user code not found") + } + + return entry.state, nil +} + +func (s *Storage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error { + s.lock.Lock() + defer s.lock.Unlock() + + entry, ok := s.deviceCodes[s.userCodes[userCode]] + if !ok { + return errors.New("user code not found") + } + + entry.state.Subject = subject + entry.state.Done = true + return nil +} + +func (s *Storage) DenyDeviceAuthorization(ctx context.Context, userCode string) error { + s.lock.Lock() + defer s.lock.Unlock() + + s.deviceCodes[s.userCodes[userCode]].state.Denied = true + return nil +} diff --git a/pkg/client/client.go b/pkg/client/client.go index 077baf2..b9ae008 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -1,6 +1,8 @@ package client import ( + "context" + "encoding/json" "errors" "fmt" "io" @@ -186,3 +188,94 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti IssuedAt: oidc.Time(iat), }, signer) } + +type DeviceAuthorizationCaller interface { + GetDeviceAuthorizationEndpoint() string + HttpClient() *http.Client +} + +func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) { + req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil) + if err != nil { + return nil, err + } + if request.ClientSecret != "" { + req.SetBasicAuth(request.ClientID, request.ClientSecret) + } + + resp := new(oidc.DeviceAuthorizationResponse) + if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil { + return nil, err + } + return resp, nil +} + +type DeviceAccessTokenRequest struct { + *oidc.ClientCredentialsRequest + oidc.DeviceAccessTokenRequest +} + +func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) { + req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil) + if err != nil { + return nil, err + } + if request.ClientSecret != "" { + req.SetBasicAuth(request.ClientID, request.ClientSecret) + } + + httpResp, err := caller.HttpClient().Do(req) + if err != nil { + return nil, err + } + defer httpResp.Body.Close() + + resp := new(struct { + *oidc.AccessTokenResponse + *oidc.Error + }) + if err = json.NewDecoder(httpResp.Body).Decode(resp); err != nil { + return nil, err + } + + if httpResp.StatusCode == http.StatusOK { + return resp.AccessTokenResponse, nil + } + + return nil, resp.Error +} + +func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) { + for { + timer := time.After(interval) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer: + } + + ctx, cancel := context.WithTimeout(ctx, interval) + defer cancel() + + resp, err := CallDeviceAccessTokenEndpoint(ctx, request, caller) + if err == nil { + return resp, nil + } + if errors.Is(err, context.DeadlineExceeded) { + interval += 5 * time.Second + } + var target *oidc.Error + if !errors.As(err, &target) { + return nil, err + } + switch target.ErrorType { + case oidc.AuthorizationPending: + continue + case oidc.SlowDown: + interval += 5 * time.Second + continue + default: + return nil, err + } + } +} diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go new file mode 100644 index 0000000..73b67ca --- /dev/null +++ b/pkg/client/rp/device.go @@ -0,0 +1,62 @@ +package rp + +import ( + "context" + "fmt" + "time" + + "github.com/zitadel/oidc/v2/pkg/client" + "github.com/zitadel/oidc/v2/pkg/oidc" +) + +func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) { + confg := rp.OAuthConfig() + req := &oidc.ClientCredentialsRequest{ + GrantType: oidc.GrantTypeDeviceCode, + Scope: scopes, + ClientID: confg.ClientID, + ClientSecret: confg.ClientSecret, + } + + if signer := rp.Signer(); signer != nil { + assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer) + if err != nil { + return nil, fmt.Errorf("failed to build assertion: %w", err) + } + req.ClientAssertion = assertion + req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion + } + + return req, nil +} + +// DeviceAuthorization starts a new Device Authorization flow as defined +// in RFC 8628, section 3.1 and 3.2: +// https://www.rfc-editor.org/rfc/rfc8628#section-3.1 +func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) { + req, err := newDeviceClientCredentialsRequest(scopes, rp) + if err != nil { + return nil, err + } + + return client.CallDeviceAuthorizationEndpoint(req, rp) +} + +// DeviceAccessToken attempts to obtain tokens from a Device Authorization, +// by means of polling as defined in RFC, section 3.3 and 3.4: +// https://www.rfc-editor.org/rfc/rfc8628#section-3.4 +func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) { + req := &client.DeviceAccessTokenRequest{ + DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{ + GrantType: oidc.GrantTypeDeviceCode, + DeviceCode: deviceCode, + }, + } + + req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp) + if err != nil { + return nil, err + } + + return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp}) +} diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index d2e3cf7..96fe219 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -59,6 +59,10 @@ type RelyingParty interface { // UserinfoEndpoint returns the userinfo UserinfoEndpoint() string + // GetDeviceAuthorizationEndpoint returns the enpoint which can + // be used to start a DeviceAuthorization flow. + GetDeviceAuthorizationEndpoint() string + // IDTokenVerifier returns the verifier interface used for oidc id_token verification IDTokenVerifier() IDTokenVerifier // ErrorHandler returns the handler used for callback errors @@ -121,6 +125,10 @@ func (rp *relyingParty) UserinfoEndpoint() string { return rp.endpoints.UserinfoURL } +func (rp *relyingParty) GetDeviceAuthorizationEndpoint() string { + return rp.endpoints.DeviceAuthorizationURL +} + func (rp *relyingParty) GetEndSessionEndpoint() string { return rp.endpoints.EndSessionURL } @@ -495,11 +503,12 @@ type OptionFunc func(RelyingParty) type Endpoints struct { oauth2.Endpoint - IntrospectURL string - UserinfoURL string - JKWsURL string - EndSessionURL string - RevokeURL string + IntrospectURL string + UserinfoURL string + JKWsURL string + EndSessionURL string + RevokeURL string + DeviceAuthorizationURL string } func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { @@ -509,11 +518,12 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { AuthStyle: oauth2.AuthStyleAutoDetect, TokenURL: discoveryConfig.TokenEndpoint, }, - IntrospectURL: discoveryConfig.IntrospectionEndpoint, - UserinfoURL: discoveryConfig.UserinfoEndpoint, - JKWsURL: discoveryConfig.JwksURI, - EndSessionURL: discoveryConfig.EndSessionEndpoint, - RevokeURL: discoveryConfig.RevocationEndpoint, + IntrospectURL: discoveryConfig.IntrospectionEndpoint, + UserinfoURL: discoveryConfig.UserinfoEndpoint, + JKWsURL: discoveryConfig.JwksURI, + EndSessionURL: discoveryConfig.EndSessionEndpoint, + RevokeURL: discoveryConfig.RevocationEndpoint, + DeviceAuthorizationURL: discoveryConfig.DeviceAuthorizationEndpoint, } } diff --git a/pkg/oidc/device_authorization.go b/pkg/oidc/device_authorization.go new file mode 100644 index 0000000..68b8efa --- /dev/null +++ b/pkg/oidc/device_authorization.go @@ -0,0 +1,29 @@ +package oidc + +// DeviceAuthorizationRequest implements +// https://www.rfc-editor.org/rfc/rfc8628#section-3.1, +// 3.1 Device Authorization Request. +type DeviceAuthorizationRequest struct { + Scopes SpaceDelimitedArray `schema:"scope"` + ClientID string `schema:"client_id"` +} + +// DeviceAuthorizationResponse implements +// https://www.rfc-editor.org/rfc/rfc8628#section-3.2 +// 3.2. Device Authorization Response. +type DeviceAuthorizationResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete,omitempty"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval,omitempty"` +} + +// DeviceAccessTokenRequest implements +// https://www.rfc-editor.org/rfc/rfc8628#section-3.4, +// Device Access Token Request. +type DeviceAccessTokenRequest struct { + GrantType GrantType `json:"grant_type" schema:"grant_type"` + DeviceCode string `json:"device_code" schema:"device_code"` +} diff --git a/pkg/oidc/discovery.go b/pkg/oidc/discovery.go index fbc417b..3574101 100644 --- a/pkg/oidc/discovery.go +++ b/pkg/oidc/discovery.go @@ -30,6 +30,8 @@ type DiscoveryConfiguration struct { // EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP. EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"` + // CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client. CheckSessionIframe string `json:"check_session_iframe,omitempty"` diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index 5797a59..79acecd 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -18,6 +18,14 @@ const ( InteractionRequired errorType = "interaction_required" LoginRequired errorType = "login_required" RequestNotSupported errorType = "request_not_supported" + + // Additional error codes as defined in + // https://www.rfc-editor.org/rfc/rfc8628#section-3.5 + // Device Access Token Response + AuthorizationPending errorType = "authorization_pending" + SlowDown errorType = "slow_down" + AccessDenied errorType = "access_denied" + ExpiredToken errorType = "expired_token" ) var ( @@ -77,6 +85,32 @@ var ( ErrorType: RequestNotSupported, } } + + // Device Access Token errors: + ErrAuthorizationPending = func() *Error { + return &Error{ + ErrorType: AuthorizationPending, + Description: "The client SHOULD repeat the access token request to the token endpoint, after interval from device authorization response.", + } + } + ErrSlowDown = func() *Error { + return &Error{ + ErrorType: SlowDown, + Description: "Polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests.", + } + } + ErrAccessDenied = func() *Error { + return &Error{ + ErrorType: AccessDenied, + Description: "The authorization request was denied.", + } + } + ErrExpiredDeviceCode = func() *Error { + return &Error{ + ErrorType: ExpiredToken, + Description: "The \"device_code\" has expired.", + } + } ) type Error struct { diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 6d8f186..78bd658 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -27,6 +27,9 @@ const ( // GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code GrantTypeImplicit GrantType = "implicit" + // GrantTypeDeviceCode + GrantTypeDeviceCode GrantType = "urn:ietf:params:oauth:grant-type:device_code" + // ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer` // used for the OAuth JWT Profile Client Authentication ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" @@ -35,7 +38,7 @@ const ( var AllGrantTypes = []GrantType{ GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials, GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit, - ClientAssertionTypeJWTAssertion, + GrantTypeDeviceCode, ClientAssertionTypeJWTAssertion, } type GrantType string diff --git a/pkg/op/client.go b/pkg/op/client.go index e8a3347..1f5e1c9 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -1,8 +1,13 @@ package op import ( + "context" + "errors" + "net/http" + "net/url" "time" + httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" ) @@ -57,3 +62,95 @@ func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseT func IsConfidentialType(c Client) bool { return c.ApplicationType() == ApplicationTypeWeb } + +var ( + ErrInvalidAuthHeader = errors.New("invalid basic auth header") + ErrNoClientCredentials = errors.New("no client credentials provided") + ErrMissingClientID = errors.New("client_id missing from request") +) + +type ClientJWTProfile interface { + JWTProfileVerifier(context.Context) JWTProfileVerifier +} + +func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) { + if ca.ClientAssertion == "" { + return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials) + } + + profile, err := VerifyJWTAssertion(ctx, ca.ClientAssertion, verifier.JWTProfileVerifier(ctx)) + if err != nil { + return "", oidc.ErrUnauthorizedClient().WithParent(err).WithDescription("JWT assertion failed") + } + return profile.Issuer, nil +} + +func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) { + clientID, clientSecret, ok := r.BasicAuth() + if !ok { + return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials) + } + clientID, err = url.QueryUnescape(clientID) + if err != nil { + return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader) + } + clientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader) + } + if err := storage.AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil { + return "", oidc.ErrUnauthorizedClient().WithParent(err) + } + return clientID, nil +} + +type ClientProvider interface { + Decoder() httphelper.Decoder + Storage() Storage +} + +type clientData struct { + ClientID string `schema:"client_id"` + oidc.ClientAssertionParams +} + +// ClientIDFromRequest parses the request form and tries to obtain the client ID +// and reports if it is authenticated, using a JWT or static client secrets over +// http basic auth. +// +// If the Provider implements IntrospectorJWTProfile and "client_assertion" is +// present in the form data, JWT assertion will be verified and the +// client ID is taken from there. +// If any of them is absent, basic auth is attempted. +// In absence of basic auth data, the unauthenticated client id from the form +// data is returned. +// +// If no client id can be obtained by any method, oidc.ErrInvalidClient +// is returned with ErrMissingClientID wrapped in it. +func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, authenticated bool, err error) { + err = r.ParseForm() + if err != nil { + return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) + } + + data := new(clientData) + if err = p.Decoder().Decode(data, r.PostForm); err != nil { + return "", false, err + } + + JWTProfile, ok := p.(ClientJWTProfile) + if ok { + clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile) + } + if !ok || errors.Is(err, ErrNoClientCredentials) { + clientID, err = ClientBasicAuth(r, p.Storage()) + } + if err == nil { + return clientID, true, nil + } + + if data.ClientID == "" { + return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID) + } + return data.ClientID, false, nil +} diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go new file mode 100644 index 0000000..1af4157 --- /dev/null +++ b/pkg/op/client_test.go @@ -0,0 +1,253 @@ +package op_test + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/golang/mock/gomock" + "github.com/gorilla/schema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + httphelper "github.com/zitadel/oidc/v2/pkg/http" + "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v2/pkg/op/mock" +) + +type testClientJWTProfile struct{} + +func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil } + +func TestClientJWTAuth(t *testing.T) { + type args struct { + ctx context.Context + ca oidc.ClientAssertionParams + verifier op.ClientJWTProfile + } + tests := []struct { + name string + args args + wantClientID string + wantErr error + }{ + { + name: "empty assertion", + args: args{ + context.Background(), + oidc.ClientAssertionParams{}, + testClientJWTProfile{}, + }, + wantErr: op.ErrNoClientCredentials, + }, + { + name: "verification error", + args: args{ + context.Background(), + oidc.ClientAssertionParams{ + ClientAssertion: "foo", + }, + testClientJWTProfile{}, + }, + wantErr: oidc.ErrParse, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotClientID, err := op.ClientJWTAuth(tt.args.ctx, tt.args.ca, tt.args.verifier) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantClientID, gotClientID) + }) + } +} + +func TestClientBasicAuth(t *testing.T) { + errWrong := errors.New("wrong secret") + + type args struct { + username string + password string + } + tests := []struct { + name string + args *args + storage op.Storage + wantClientID string + wantErr error + }{ + { + name: "no args", + wantErr: op.ErrNoClientCredentials, + }, + { + name: "username unescape err", + args: &args{ + username: "%", + password: "bar", + }, + wantErr: op.ErrInvalidAuthHeader, + }, + { + name: "password unescape err", + args: &args{ + username: "foo", + password: "%", + }, + wantErr: op.ErrInvalidAuthHeader, + }, + { + name: "auth error", + args: &args{ + username: "foo", + password: "wrong", + }, + storage: func() op.Storage { + s := mock.NewMockStorage(gomock.NewController(t)) + s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong) + return s + }(), + wantErr: errWrong, + }, + { + name: "auth error", + args: &args{ + username: "foo", + password: "bar", + }, + storage: func() op.Storage { + s := mock.NewMockStorage(gomock.NewController(t)) + s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil) + return s + }(), + wantClientID: "foo", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/foo", nil) + if tt.args != nil { + r.SetBasicAuth(tt.args.username, tt.args.password) + } + + gotClientID, err := op.ClientBasicAuth(r, tt.storage) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantClientID, gotClientID) + }) + } +} + +type errReader struct{} + +func (errReader) Read([]byte) (int, error) { + return 0, io.ErrNoProgress +} + +type testClientProvider struct { + storage op.Storage +} + +func (testClientProvider) Decoder() httphelper.Decoder { + return schema.NewDecoder() +} + +func (p testClientProvider) Storage() op.Storage { + return p.storage +} + +func TestClientIDFromRequest(t *testing.T) { + type args struct { + body io.Reader + p op.ClientProvider + } + type basicAuth struct { + username string + password string + } + tests := []struct { + name string + args args + basicAuth *basicAuth + wantClientID string + wantAuthenticated bool + wantErr bool + }{ + { + name: "parse error", + args: args{ + body: errReader{}, + }, + wantErr: true, + }, + { + name: "unauthenticated", + args: args{ + body: strings.NewReader( + url.Values{ + "client_id": []string{"foo"}, + }.Encode(), + ), + p: testClientProvider{ + storage: mock.NewStorage(t), + }, + }, + wantClientID: "foo", + wantAuthenticated: false, + }, + { + name: "authenticated", + args: args{ + body: strings.NewReader( + url.Values{}.Encode(), + ), + p: testClientProvider{ + storage: func() op.Storage { + s := mock.NewMockStorage(gomock.NewController(t)) + s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil) + return s + }(), + }, + }, + basicAuth: &basicAuth{ + username: "foo", + password: "bar", + }, + wantClientID: "foo", + wantAuthenticated: true, + }, + { + name: "missing client id", + args: args{ + body: strings.NewReader( + url.Values{}.Encode(), + ), + p: testClientProvider{ + storage: mock.NewStorage(t), + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if tt.basicAuth != nil { + r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password) + } + + gotClientID, gotAuthenticated, err := op.ClientIDFromRequest(r, tt.args.p) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.wantClientID, gotClientID) + assert.Equal(t, tt.wantAuthenticated, gotAuthenticated) + }) + } +} diff --git a/pkg/op/config.go b/pkg/op/config.go index c40fa2d..c40ed39 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -27,6 +27,7 @@ type Configuration interface { RevocationEndpoint() Endpoint EndSessionEndpoint() Endpoint KeysEndpoint() Endpoint + DeviceAuthorizationEndpoint() Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool @@ -36,6 +37,7 @@ type Configuration interface { GrantTypeTokenExchangeSupported() bool GrantTypeJWTAuthorizationSupported() bool GrantTypeClientCredentialsSupported() bool + GrantTypeDeviceCodeSupported() bool IntrospectionAuthMethodPrivateKeyJWTSupported() bool IntrospectionEndpointSigningAlgorithmsSupported() []string RevocationAuthMethodPrivateKeyJWTSupported() bool @@ -44,6 +46,7 @@ type Configuration interface { RequestObjectSigningAlgorithmsSupported() []string SupportedUILocales() []language.Tag + DeviceAuthorization() DeviceAuthorizationConfig } type IssuerFromRequest func(r *http.Request) string diff --git a/pkg/op/device.go b/pkg/op/device.go new file mode 100644 index 0000000..04c06f2 --- /dev/null +++ b/pkg/op/device.go @@ -0,0 +1,265 @@ +package op + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "math/big" + "net/http" + "strings" + "time" + + httphelper "github.com/zitadel/oidc/v2/pkg/http" + "github.com/zitadel/oidc/v2/pkg/oidc" +) + +type DeviceAuthorizationConfig struct { + Lifetime time.Duration + PollInterval time.Duration + UserFormURL string // the URL where the user must go to authorize the device + UserCode UserCodeConfig +} + +type UserCodeConfig struct { + CharSet string + CharAmount int + DashInterval int +} + +const ( + CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ" + CharSetDigits = "0123456789" +) + +var ( + UserCodeBase20 = UserCodeConfig{ + CharSet: CharSetBase20, + CharAmount: 8, + DashInterval: 4, + } + UserCodeDigits = UserCodeConfig{ + CharSet: CharSetDigits, + CharAmount: 9, + DashInterval: 3, + } +) + +func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if err := DeviceAuthorization(w, r, o); err != nil { + RequestError(w, r, err) + } + } +} + +func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error { + storage, err := assertDeviceStorage(o.Storage()) + if err != nil { + return err + } + + req, err := ParseDeviceCodeRequest(r, o) + if err != nil { + return err + } + + config := o.DeviceAuthorization() + + deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes) + if err != nil { + return err + } + userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval) + if err != nil { + return err + } + + expires := time.Now().Add(config.Lifetime) + err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes) + if err != nil { + return err + } + + response := &oidc.DeviceAuthorizationResponse{ + DeviceCode: deviceCode, + UserCode: userCode, + VerificationURI: config.UserFormURL, + ExpiresIn: int(config.Lifetime / time.Second), + Interval: int(config.PollInterval / time.Second), + } + + response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode) + + httphelper.MarshalJSON(w, response) + return nil +} + +func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) { + clientID, _, err := ClientIDFromRequest(r, o) + if err != nil { + return nil, err + } + + req := new(oidc.DeviceAuthorizationRequest) + if err := o.Decoder().Decode(req, r.Form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err) + } + req.ClientID = clientID + + return req, nil +} + +// 16 bytes gives 128 bit of entropy. +// results in a 22 character base64 encoded string. +const RecommendedDeviceCodeBytes = 16 + +func NewDeviceCode(nBytes int) (string, error) { + bytes := make([]byte, nBytes) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("%w getting entropy for device code", err) + } + return base64.RawURLEncoding.EncodeToString(bytes), nil +} + +func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) { + var buf strings.Builder + if dashInterval > 0 { + buf.Grow(charAmount + charAmount/dashInterval - 1) + } else { + buf.Grow(charAmount) + } + + max := big.NewInt(int64(len(charSet))) + + for i := 0; i < charAmount; i++ { + if dashInterval != 0 && i != 0 && i%dashInterval == 0 { + buf.WriteByte('-') + } + + bi, err := rand.Int(rand.Reader, max) + if err != nil { + return "", fmt.Errorf("%w getting entropy for user code", err) + } + + buf.WriteRune(charSet[int(bi.Int64())]) + } + + return buf.String(), nil +} + +type deviceAccessTokenRequest struct { + subject string + audience []string + scopes []string +} + +func (r *deviceAccessTokenRequest) GetSubject() string { + return r.subject +} + +func (r *deviceAccessTokenRequest) GetAudience() []string { + return r.audience +} + +func (r *deviceAccessTokenRequest) GetScopes() []string { + return r.scopes +} + +func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + if err := deviceAccessToken(w, r, exchanger); err != nil { + RequestError(w, r, err) + } +} + +func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error { + // use a limited context timeout shorter as the default + // poll interval of 5 seconds. + ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second) + defer cancel() + r = r.WithContext(ctx) + + clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger) + if err != nil { + return err + } + + req, err := ParseDeviceAccessTokenRequest(r, exchanger) + if err != nil { + return err + } + state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) + if err != nil { + return err + } + + client, err := exchanger.Storage().GetClientByClientID(ctx, clientID) + if err != nil { + return err + } + if clientAuthenticated != IsConfidentialType(client) { + return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials). + WithDescription("confidential client requires authentication") + } + + tokenRequest := &deviceAccessTokenRequest{ + subject: state.Subject, + audience: []string{clientID}, + scopes: state.Scopes, + } + resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client) + if err != nil { + return err + } + + httphelper.MarshalJSON(w, resp) + return nil +} + +func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) { + req := new(oidc.DeviceAccessTokenRequest) + if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil { + return nil, err + } + return req, nil +} + +func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) { + storage, err := assertDeviceStorage(exchanger.Storage()) + if err != nil { + return nil, err + } + + state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode) + if errors.Is(err, context.DeadlineExceeded) { + return nil, oidc.ErrSlowDown().WithParent(err) + } + if err != nil { + return nil, oidc.ErrAccessDenied().WithParent(err) + } + if state.Denied { + return state, oidc.ErrAccessDenied() + } + if state.Done { + return state, nil + } + if time.Now().After(state.Expires) { + return state, oidc.ErrExpiredDeviceCode() + } + return state, oidc.ErrAuthorizationPending() +} + +func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) { + accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, client, "") + if err != nil { + return nil, err + } + + return &oidc.AccessTokenResponse{ + AccessToken: accessToken, + RefreshToken: refreshToken, + TokenType: oidc.BearerToken, + ExpiresIn: uint64(validity.Seconds()), + }, nil +} diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go new file mode 100644 index 0000000..ca68759 --- /dev/null +++ b/pkg/op/device_test.go @@ -0,0 +1,453 @@ +package op_test + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "io" + mr "math/rand" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v2/example/server/storage" + "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v2/pkg/op" + "golang.org/x/text/language" +) + +var testProvider op.OpenIDProvider + +const ( + testIssuer = "https://localhost:9998/" + pathLoggedOut = "/logged-out" +) + +func init() { + config := &op.Config{ + CryptoKey: sha256.Sum256([]byte("test")), + DefaultLogoutRedirectURI: pathLoggedOut, + CodeMethodS256: true, + AuthMethodPost: true, + AuthMethodPrivateKeyJWT: true, + GrantTypeRefreshToken: true, + RequestObjectSupported: true, + SupportedUILocales: []language.Tag{language.English}, + DeviceAuthorization: op.DeviceAuthorizationConfig{ + Lifetime: 5 * time.Minute, + PollInterval: 5 * time.Second, + UserFormURL: testIssuer + "device", + UserCode: op.UserCodeBase20, + }, + } + + storage.RegisterClients( + storage.NativeClient("native"), + storage.WebClient("web", "secret"), + storage.WebClient("api", "secret"), + ) + + var err error + testProvider, err = op.NewOpenIDProvider(context.TODO(), testIssuer, config, + storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(), + ) + if err != nil { + panic(err) + } +} + +func Test_deviceAuthorizationHandler(t *testing.T) { + req := &oidc.DeviceAuthorizationRequest{ + Scopes: []string{"foo", "bar"}, + ClientID: "web", + } + values := make(url.Values) + testProvider.Encoder().Encode(req, values) + body := strings.NewReader(values.Encode()) + + r := httptest.NewRequest(http.MethodPost, "/", body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + w := httptest.NewRecorder() + + runWithRandReader(mr.New(mr.NewSource(1)), func() { + op.DeviceAuthorizationHandler(testProvider)(w, r) + }) + + result := w.Result() + + assert.Less(t, result.StatusCode, 300) + + got, _ := io.ReadAll(result.Body) + assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got)) +} + +func TestParseDeviceCodeRequest(t *testing.T) { + tests := []struct { + name string + req *oidc.DeviceAuthorizationRequest + wantErr bool + }{ + { + name: "empty request", + wantErr: true, + }, + /* decoding a SpaceDelimitedArray is broken + https://github.com/zitadel/oidc/issues/295 + { + name: "success", + req: &oidc.DeviceAuthorizationRequest{ + Scopes: oidc.SpaceDelimitedArray{"foo", "bar"}, + ClientID: "web", + }, + }, + */ + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var body io.Reader + if tt.req != nil { + values := make(url.Values) + testProvider.Encoder().Encode(tt.req, values) + body = strings.NewReader(values.Encode()) + } + + r := httptest.NewRequest(http.MethodPost, "/", body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + got, err := op.ParseDeviceCodeRequest(r, testProvider) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.req, got) + }) + } +} + +func runWithRandReader(r io.Reader, f func()) { + originalReader := rand.Reader + rand.Reader = r + defer func() { + rand.Reader = originalReader + }() + + f() +} + +func TestNewDeviceCode(t *testing.T) { + t.Run("reader error", func(t *testing.T) { + runWithRandReader(errReader{}, func() { + _, err := op.NewDeviceCode(16) + require.Error(t, err) + }) + }) + + t.Run("different lengths, rand reader", func(t *testing.T) { + for i := 1; i <= 32; i++ { + got, err := op.NewDeviceCode(i) + require.NoError(t, err) + assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i)) + } + }) + +} + +func TestNewUserCode(t *testing.T) { + type args struct { + charset []rune + charAmount int + dashInterval int + } + tests := []struct { + name string + args args + reader io.Reader + want string + wantErr bool + }{ + { + name: "reader error", + args: args{ + charset: []rune(op.CharSetBase20), + charAmount: 8, + dashInterval: 4, + }, + reader: errReader{}, + wantErr: true, + }, + { + name: "base20", + args: args{ + charset: []rune(op.CharSetBase20), + charAmount: 8, + dashInterval: 4, + }, + reader: mr.New(mr.NewSource(1)), + want: "XKCD-HTTD", + }, + { + name: "digits", + args: args{ + charset: []rune(op.CharSetDigits), + charAmount: 9, + dashInterval: 3, + }, + reader: mr.New(mr.NewSource(1)), + want: "271-256-225", + }, + { + name: "no dashes", + args: args{ + charset: []rune(op.CharSetDigits), + charAmount: 9, + }, + reader: mr.New(mr.NewSource(1)), + want: "271256225", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + runWithRandReader(tt.reader, func() { + got, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval) + if tt.wantErr { + require.ErrorIs(t, err, io.ErrNoProgress) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + + }) + } + + t.Run("crypto/rand", func(t *testing.T) { + const testN = 100000 + + for _, c := range []op.UserCodeConfig{op.UserCodeBase20, op.UserCodeDigits} { + t.Run(c.CharSet, func(t *testing.T) { + results := make(map[string]int) + + for i := 0; i < testN; i++ { + code, err := op.NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval) + require.NoError(t, err) + results[code]++ + } + + t.Log(results) + + var duplicates int + for code, count := range results { + assert.Less(t, count, 3, code) + if count == 2 { + duplicates++ + } + } + + }) + } + }) +} + +func BenchmarkNewUserCode(b *testing.B) { + type args struct { + charset []rune + charAmount int + dashInterval int + } + tests := []struct { + name string + args args + reader io.Reader + }{ + { + name: "math rand, base20", + args: args{ + charset: []rune(op.CharSetBase20), + charAmount: 8, + dashInterval: 4, + }, + reader: mr.New(mr.NewSource(1)), + }, + { + name: "math rand, digits", + args: args{ + charset: []rune(op.CharSetDigits), + charAmount: 9, + dashInterval: 3, + }, + reader: mr.New(mr.NewSource(1)), + }, + { + name: "crypto rand, base20", + args: args{ + charset: []rune(op.CharSetBase20), + charAmount: 8, + dashInterval: 4, + }, + reader: rand.Reader, + }, + { + name: "crypto rand, digits", + args: args{ + charset: []rune(op.CharSetDigits), + charAmount: 9, + dashInterval: 3, + }, + reader: rand.Reader, + }, + } + for _, tt := range tests { + runWithRandReader(tt.reader, func() { + b.Run(tt.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval) + require.NoError(b, err) + } + }) + + }) + } +} + +func TestDeviceAccessToken(t *testing.T) { + storage := testProvider.Storage().(op.DeviceAuthorizationStorage) + storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"}) + storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim") + + values := make(url.Values) + values.Set("client_id", "native") + values.Set("grant_type", string(oidc.GrantTypeDeviceCode)) + values.Set("device_code", "qwerty") + + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(values.Encode())) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + + op.DeviceAccessToken(w, r, testProvider) + + result := w.Result() + got, _ := io.ReadAll(result.Body) + t.Log(string(got)) + assert.Less(t, result.StatusCode, 300) + assert.NotEmpty(t, string(got)) +} + +func TestCheckDeviceAuthorizationState(t *testing.T) { + now := time.Now() + + storage := testProvider.Storage().(op.DeviceAuthorizationStorage) + storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"}) + storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"}) + storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"}) + storage.StoreDeviceAuthorization(context.Background(), "native", "expired", "expired", now.Add(-time.Minute), []string{"foo"}) + + storage.DenyDeviceAuthorization(context.Background(), "denied") + storage.CompleteDeviceAuthorization(context.Background(), "completed", "tim") + + exceededCtx, cancel := context.WithTimeout(context.Background(), -time.Second) + defer cancel() + + type args struct { + ctx context.Context + clientID string + deviceCode string + } + tests := []struct { + name string + args args + want *op.DeviceAuthorizationState + wantErr error + }{ + { + name: "pending", + args: args{ + ctx: context.Background(), + clientID: "native", + deviceCode: "pending", + }, + want: &op.DeviceAuthorizationState{ + ClientID: "native", + Scopes: []string{"foo"}, + Expires: now.Add(time.Minute), + }, + wantErr: oidc.ErrAuthorizationPending(), + }, + { + name: "slow down", + args: args{ + ctx: exceededCtx, + clientID: "native", + deviceCode: "ok", + }, + wantErr: oidc.ErrSlowDown(), + }, + { + name: "wrong client", + args: args{ + ctx: context.Background(), + clientID: "foo", + deviceCode: "ok", + }, + wantErr: oidc.ErrAccessDenied(), + }, + { + name: "denied", + args: args{ + ctx: context.Background(), + clientID: "native", + deviceCode: "denied", + }, + want: &op.DeviceAuthorizationState{ + ClientID: "native", + Scopes: []string{"foo"}, + Expires: now.Add(time.Minute), + Denied: true, + }, + wantErr: oidc.ErrAccessDenied(), + }, + { + name: "completed", + args: args{ + ctx: context.Background(), + clientID: "native", + deviceCode: "completed", + }, + want: &op.DeviceAuthorizationState{ + ClientID: "native", + Scopes: []string{"foo"}, + Expires: now.Add(time.Minute), + Subject: "tim", + Done: true, + }, + }, + { + name: "expired", + args: args{ + ctx: context.Background(), + clientID: "native", + deviceCode: "expired", + }, + want: &op.DeviceAuthorizationState{ + ClientID: "native", + Scopes: []string{"foo"}, + Expires: now.Add(-time.Minute), + }, + wantErr: oidc.ErrExpiredDeviceCode(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.CheckDeviceAuthorizationState(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, testProvider) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 9a25afc..26f89eb 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -44,6 +44,7 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer), EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer), JwksURI: config.KeysEndpoint().Absolute(issuer), + DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer), ScopesSupported: Scopes(config), ResponseTypesSupported: ResponseTypes(config), GrantTypesSupported: GrantTypes(config), @@ -92,6 +93,9 @@ func GrantTypes(c Configuration) []oidc.GrantType { if c.GrantTypeJWTAuthorizationSupported() { grantTypes = append(grantTypes, oidc.GrantTypeBearer) } + if c.GrantTypeDeviceCodeSupported() { + grantTypes = append(grantTypes, oidc.GrantTypeDeviceCode) + } return grantTypes } diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index e1b07dd..2d0b8af 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -131,6 +131,7 @@ func Test_GrantTypes(t *testing.T) { c.EXPECT().GrantTypeTokenExchangeSupported().Return(false) c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(false) c.EXPECT().GrantTypeClientCredentialsSupported().Return(false) + c.EXPECT().GrantTypeDeviceCodeSupported().Return(false) return c }(), }, @@ -148,6 +149,7 @@ func Test_GrantTypes(t *testing.T) { c.EXPECT().GrantTypeTokenExchangeSupported().Return(true) c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(true) c.EXPECT().GrantTypeClientCredentialsSupported().Return(true) + c.EXPECT().GrantTypeDeviceCodeSupported().Return(false) return c }(), }, diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index fc3158a..44b5ceb 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -92,6 +92,34 @@ func (mr *MockConfigurationMockRecorder) CodeMethodS256Supported() *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported)) } +// DeviceAuthorization mocks base method. +func (m *MockConfiguration) DeviceAuthorization() op.DeviceAuthorizationConfig { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeviceAuthorization") + ret0, _ := ret[0].(op.DeviceAuthorizationConfig) + return ret0 +} + +// DeviceAuthorization indicates an expected call of DeviceAuthorization. +func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorization", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorization)) +} + +// DeviceAuthorizationEndpoint mocks base method. +func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint") + ret0, _ := ret[0].(op.Endpoint) + return ret0 +} + +// DeviceAuthorizationEndpoint indicates an expected call of DeviceAuthorizationEndpoint. +func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorizationEndpoint)) +} + // EndSessionEndpoint mocks base method. func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { m.ctrl.T.Helper() @@ -120,6 +148,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeClientCredentialsSupported() * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported)) } +// GrantTypeDeviceCodeSupported mocks base method. +func (m *MockConfiguration) GrantTypeDeviceCodeSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GrantTypeDeviceCodeSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// GrantTypeDeviceCodeSupported indicates an expected call of GrantTypeDeviceCodeSupported. +func (mr *MockConfigurationMockRecorder) GrantTypeDeviceCodeSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeDeviceCodeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeDeviceCodeSupported)) +} + // GrantTypeJWTAuthorizationSupported mocks base method. func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool { m.ctrl.T.Helper() @@ -358,6 +400,20 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported)) } +// UserCodeFormEndpoint mocks base method. +func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UserCodeFormEndpoint") + ret0, _ := ret[0].(op.Endpoint) + return ret0 +} + +// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint. +func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint)) +} + // UserinfoEndpoint mocks base method. func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { m.ctrl.T.Helper() diff --git a/pkg/op/op.go b/pkg/op/op.go index 699fb45..2859722 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -27,17 +27,19 @@ const ( defaultRevocationEndpoint = "revoke" defaultEndSessionEndpoint = "end_session" defaultKeysEndpoint = "keys" + defaultDeviceAuthzEndpoint = "/device_authorization" ) var ( DefaultEndpoints = &endpoints{ - Authorization: NewEndpoint(defaultAuthorizationEndpoint), - Token: NewEndpoint(defaultTokenEndpoint), - Introspection: NewEndpoint(defaultIntrospectEndpoint), - Userinfo: NewEndpoint(defaultUserinfoEndpoint), - Revocation: NewEndpoint(defaultRevocationEndpoint), - EndSession: NewEndpoint(defaultEndSessionEndpoint), - JwksURI: NewEndpoint(defaultKeysEndpoint), + Authorization: NewEndpoint(defaultAuthorizationEndpoint), + Token: NewEndpoint(defaultTokenEndpoint), + Introspection: NewEndpoint(defaultIntrospectEndpoint), + Userinfo: NewEndpoint(defaultUserinfoEndpoint), + Revocation: NewEndpoint(defaultRevocationEndpoint), + EndSession: NewEndpoint(defaultEndSessionEndpoint), + JwksURI: NewEndpoint(defaultKeysEndpoint), + DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint), } defaultCORSOptions = cors.Options{ @@ -95,6 +97,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o)) router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o)) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) + router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o)) return router } @@ -118,17 +121,19 @@ type Config struct { GrantTypeRefreshToken bool RequestObjectSupported bool SupportedUILocales []language.Tag + DeviceAuthorization DeviceAuthorizationConfig } type endpoints struct { - Authorization Endpoint - Token Endpoint - Introspection Endpoint - Userinfo Endpoint - Revocation Endpoint - EndSession Endpoint - CheckSessionIframe Endpoint - JwksURI Endpoint + Authorization Endpoint + Token Endpoint + Introspection Endpoint + Userinfo Endpoint + Revocation Endpoint + EndSession Endpoint + CheckSessionIframe Endpoint + JwksURI Endpoint + DeviceAuthorization Endpoint } // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) @@ -145,6 +150,7 @@ type endpoints struct { // /revoke // /end_session // /keys +// /device_authorization // // This does not include login. Login is handled with a redirect that includes the // request ID. The redirect for logins is specified per-client by Client.LoginURL(). @@ -242,6 +248,10 @@ func (o *Provider) EndSessionEndpoint() Endpoint { return o.endpoints.EndSession } +func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { + return o.endpoints.DeviceAuthorization +} + func (o *Provider) KeysEndpoint() Endpoint { return o.endpoints.JwksURI } @@ -275,6 +285,11 @@ func (o *Provider) GrantTypeJWTAuthorizationSupported() bool { return true } +func (o *Provider) GrantTypeDeviceCodeSupported() bool { + _, ok := o.storage.(DeviceAuthorizationStorage) + return ok +} + func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool { return true } @@ -308,6 +323,10 @@ func (o *Provider) SupportedUILocales() []language.Tag { return o.config.SupportedUILocales } +func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig { + return o.config.DeviceAuthorization +} + func (o *Provider) Storage() Storage { return o.storage } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 1e19c76..ebab1c3 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -151,3 +151,50 @@ type EndSessionRequest struct { ClientID string RedirectURI string } + +var ErrDuplicateUserCode = errors.New("user code already exists") + +type DeviceAuthorizationState struct { + ClientID string + Scopes []string + Expires time.Time + Done bool + Subject string + Denied bool +} + +type DeviceAuthorizationStorage interface { + // StoreDeviceAuthorizationRequest stores a new device authorization request in the database. + // User code will be used by the user to complete the login flow and must be unique. + // ErrDuplicateUserCode signals the caller should try again with a new code. + // + // Note that user codes are low entropy keys and when many exist in the + // database, the change for collisions increases. Therefore implementers + // of this interface must make sure that user codes of expired authentication flows are purged, + // after some time. + StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error + + // GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database. + // The method is polled untill the the authorization is eighter Completed, Expired or Denied. + GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error) + + // GetDeviceAuthorizationByUserCode resturn the current state of the device authorization flow, + // identified by the user code. + GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*DeviceAuthorizationState, error) + + // CompleteDeviceAuthorization marks a device authorization entry as Completed, + // identified by userCode. The Subject is added to the state, so that + // GetDeviceAuthorizatonState can use it to create a new Access Token. + CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error + + // DenyDeviceAuthorization marks a device authorization entry as Denied. + DenyDeviceAuthorization(ctx context.Context, userCode string) error +} + +func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) { + storage, ok := s.(DeviceAuthorizationStorage) + if !ok { + return nil, oidc.ErrUnsupportedGrantType().WithDescription("device_code grant not supported") + } + return storage, nil +} diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go index dfc8954..e7ca7c4 100644 --- a/pkg/op/token_intospection.go +++ b/pkg/op/token_intospection.go @@ -4,7 +4,6 @@ import ( "context" "errors" "net/http" - "net/url" httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" @@ -50,38 +49,19 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto } func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) { - err = r.ParseForm() + clientID, authenticated, err := ClientIDFromRequest(r, introspector) if err != nil { - return "", "", errors.New("unable to parse request") + return "", "", err } - req := new(struct { - oidc.IntrospectionRequest - oidc.ClientAssertionParams - }) + if !authenticated { + return "", "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials) + } + + req := new(oidc.IntrospectionRequest) err = introspector.Decoder().Decode(req, r.Form) if err != nil { return "", "", errors.New("unable to parse request") } - if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" { - profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier(r.Context())) - if err == nil { - return req.Token, profile.Issuer, nil - } - } - clientID, clientSecret, ok := r.BasicAuth() - if ok { - clientID, err = url.QueryUnescape(clientID) - if err != nil { - return "", "", errors.New("invalid basic auth header") - } - clientSecret, err = url.QueryUnescape(clientSecret) - if err != nil { - return "", "", errors.New("invalid basic auth header") - } - if err := introspector.Storage().AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil { - return "", "", err - } - return req.Token, clientID, nil - } - return "", "", errors.New("invalid authorization") + + return req.Token, clientID, nil } diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index 3d65ea0..b9e9805 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -19,6 +19,7 @@ type Exchanger interface { GrantTypeTokenExchangeSupported() bool GrantTypeJWTAuthorizationSupported() bool GrantTypeClientCredentialsSupported() bool + GrantTypeDeviceCodeSupported() bool AccessTokenVerifier(context.Context) AccessTokenVerifier IDTokenHintVerifier(context.Context) IDTokenHintVerifier } @@ -56,6 +57,11 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { ClientCredentialsExchange(w, r, exchanger) return } + case string(oidc.GrantTypeDeviceCode): + if exchanger.GrantTypeDeviceCodeSupported() { + DeviceAccessToken(w, r, exchanger) + return + } case "": RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing")) return