diff --git a/example/client/device/device.go b/example/client/device/device.go
new file mode 100644
index 0000000..284ba37
--- /dev/null
+++ b/example/client/device/device.go
@@ -0,0 +1,61 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/signal"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/sirupsen/logrus"
+
+ "github.com/zitadel/oidc/v2/pkg/client/rp"
+ httphelper "github.com/zitadel/oidc/v2/pkg/http"
+)
+
+var (
+ key = []byte("test1234test1234")
+)
+
+func main() {
+ ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
+ defer stop()
+
+ clientID := os.Getenv("CLIENT_ID")
+ clientSecret := os.Getenv("CLIENT_SECRET")
+ keyPath := os.Getenv("KEY_PATH")
+ issuer := os.Getenv("ISSUER")
+ scopes := strings.Split(os.Getenv("SCOPES"), " ")
+
+ cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
+
+ var options []rp.Option
+ if clientSecret == "" {
+ options = append(options, rp.WithPKCE(cookieHandler))
+ }
+ if keyPath != "" {
+ options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
+ }
+
+ provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...)
+ if err != nil {
+ logrus.Fatalf("error creating provider %s", err.Error())
+ }
+
+ logrus.Info("starting device authorization flow")
+ resp, err := rp.DeviceAuthorization(scopes, provider)
+ if err != nil {
+ logrus.Fatal(err)
+ }
+ logrus.Info("resp", resp)
+ fmt.Printf("\nPlease browse to %s and enter code %s\n", resp.VerificationURI, resp.UserCode)
+
+ logrus.Info("start polling")
+ token, err := rp.DeviceAccessToken(ctx, resp.DeviceCode, time.Duration(resp.Interval)*time.Second, provider)
+ if err != nil {
+ logrus.Fatal(err)
+ }
+ logrus.Infof("successfully obtained token: %v", token)
+}
diff --git a/example/server/exampleop/device.go b/example/server/exampleop/device.go
new file mode 100644
index 0000000..ae2e8f2
--- /dev/null
+++ b/example/server/exampleop/device.go
@@ -0,0 +1,191 @@
+package exampleop
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+
+ "github.com/gorilla/mux"
+ "github.com/gorilla/securecookie"
+ "github.com/sirupsen/logrus"
+ "github.com/zitadel/oidc/v2/pkg/op"
+)
+
+type deviceAuthenticate interface {
+ CheckUsernamePasswordSimple(username, password string) error
+ op.DeviceAuthorizationStorage
+}
+
+type deviceLogin struct {
+ storage deviceAuthenticate
+ cookie *securecookie.SecureCookie
+}
+
+func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) {
+ l := &deviceLogin{
+ storage: storage,
+ cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil),
+ }
+
+ router.HandleFunc("", l.userCodeHandler)
+ router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler)
+ router.HandleFunc("/confirm", l.confirmHandler)
+}
+
+func renderUserCode(w io.Writer, err error) {
+ data := struct {
+ Error string
+ }{
+ Error: errMsg(err),
+ }
+
+ if err := templates.ExecuteTemplate(w, "usercode", data); err != nil {
+ logrus.Error(err)
+ }
+}
+
+func renderDeviceLogin(w http.ResponseWriter, userCode string, err error) {
+ data := &struct {
+ UserCode string
+ Error string
+ }{
+ UserCode: userCode,
+ Error: errMsg(err),
+ }
+ if err = templates.ExecuteTemplate(w, "device_login", data); err != nil {
+ logrus.Error(err)
+ }
+}
+
+func renderConfirmPage(w http.ResponseWriter, username, clientID string, scopes []string) {
+ data := &struct {
+ Username string
+ ClientID string
+ Scopes []string
+ }{
+ Username: username,
+ ClientID: clientID,
+ Scopes: scopes,
+ }
+ if err := templates.ExecuteTemplate(w, "confirm_device", data); err != nil {
+ logrus.Error(err)
+ }
+}
+
+func (d *deviceLogin) userCodeHandler(w http.ResponseWriter, r *http.Request) {
+ err := r.ParseForm()
+ if err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ renderUserCode(w, err)
+ return
+ }
+ userCode := r.Form.Get("user_code")
+ if userCode == "" {
+ if prompt, _ := url.QueryUnescape(r.Form.Get("prompt")); prompt != "" {
+ err = errors.New(prompt)
+ }
+ renderUserCode(w, err)
+ return
+ }
+
+ renderDeviceLogin(w, userCode, nil)
+}
+
+func redirectBack(w http.ResponseWriter, r *http.Request, prompt string) {
+ values := make(url.Values)
+ values.Set("prompt", url.QueryEscape(prompt))
+
+ url := url.URL{
+ Path: "/device",
+ RawQuery: values.Encode(),
+ }
+ http.Redirect(w, r, url.String(), http.StatusSeeOther)
+}
+
+const userCodeCookieName = "user_code"
+
+type userCodeCookie struct {
+ UserCode string
+ UserName string
+}
+
+func (d *deviceLogin) loginHandler(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+
+ userCode := r.PostForm.Get("user_code")
+ if userCode == "" {
+ redirectBack(w, r, "missing user_code in request")
+ return
+ }
+ username := r.PostForm.Get("username")
+ if username == "" {
+ redirectBack(w, r, "missing username in request")
+ return
+ }
+ password := r.PostForm.Get("password")
+ if password == "" {
+ redirectBack(w, r, "missing password in request")
+ return
+ }
+
+ if err := d.storage.CheckUsernamePasswordSimple(username, password); err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+ state, err := d.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode)
+ if err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+
+ encoded, err := d.cookie.Encode(userCodeCookieName, userCodeCookie{userCode, username})
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ cookie := &http.Cookie{
+ Name: userCodeCookieName,
+ Value: encoded,
+ Path: "/",
+ }
+ http.SetCookie(w, cookie)
+ renderConfirmPage(w, username, state.ClientID, state.Scopes)
+}
+
+func (d *deviceLogin) confirmHandler(w http.ResponseWriter, r *http.Request) {
+ cookie, err := r.Cookie(userCodeCookieName)
+ if err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+ data := new(userCodeCookie)
+ if err = d.cookie.Decode(userCodeCookieName, cookie.Value, &data); err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+ if err = r.ParseForm(); err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+
+ action := r.Form.Get("action")
+ switch action {
+ case "allowed":
+ err = d.storage.CompleteDeviceAuthorization(r.Context(), data.UserCode, data.UserName)
+ case "denied":
+ err = d.storage.DenyDeviceAuthorization(r.Context(), data.UserCode)
+ default:
+ err = errors.New("action must be one of \"allow\" or \"deny\"")
+ }
+ if err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+
+ fmt.Fprintf(w, "Device authorization %s. You can now return to the device", action)
+}
diff --git a/example/server/exampleop/login.go b/example/server/exampleop/login.go
index 5da86d1..c014c9a 100644
--- a/example/server/exampleop/login.go
+++ b/example/server/exampleop/login.go
@@ -3,45 +3,11 @@ package exampleop
import (
"context"
"fmt"
- "html/template"
"net/http"
"github.com/gorilla/mux"
)
-const (
- queryAuthRequestID = "authRequestID"
-)
-
-var loginTmpl, _ = template.New("login").Parse(`
-
-
-
-
- Login
-
-
-
-
- `)
-
type login struct {
authenticate authenticate
router *mux.Router
@@ -74,23 +40,19 @@ func (l *login) loginHandler(w http.ResponseWriter, r *http.Request) {
return
}
// the oidc package will pass the id of the auth request as query parameter
- // we will use this id through the login process and therefore pass it to the login page
+ // we will use this id through the login process and therefore pass it to the login page
renderLogin(w, r.FormValue(queryAuthRequestID), nil)
}
func renderLogin(w http.ResponseWriter, id string, err error) {
- var errMsg string
- if err != nil {
- errMsg = err.Error()
- }
data := &struct {
ID string
Error string
}{
ID: id,
- Error: errMsg,
+ Error: errMsg(err),
}
- err = loginTmpl.Execute(w, data)
+ err = templates.ExecuteTemplate(w, "login", data)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go
index d3a450c..b46be7f 100644
--- a/example/server/exampleop/op.go
+++ b/example/server/exampleop/op.go
@@ -5,6 +5,7 @@ import (
"crypto/sha256"
"log"
"net/http"
+ "time"
"github.com/gorilla/mux"
"golang.org/x/text/language"
@@ -27,7 +28,8 @@ func init() {
type Storage interface {
op.Storage
- CheckUsernamePassword(username, password, id string) error
+ authenticate
+ deviceAuthenticate
}
// SetupServer creates an OIDC server with Issuer=http://localhost:
@@ -62,6 +64,9 @@ func SetupServer(ctx context.Context, issuer string, storage Storage) *mux.Route
// so we will direct all calls to /login to the login UI
router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router))
+ router.PathPrefix("/device").Subrouter()
+ registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter())
+
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
// is served on the correct path
//
@@ -99,6 +104,13 @@ func newOP(ctx context.Context, storage op.Storage, issuer string, key [32]byte)
// this example has only static texts (in English), so we'll set the here accordingly
SupportedUILocales: []language.Tag{language.English},
+
+ DeviceAuthorization: op.DeviceAuthorizationConfig{
+ Lifetime: 5 * time.Minute,
+ PollInterval: 5 * time.Second,
+ UserFormURL: issuer + "device",
+ UserCode: op.UserCodeBase20,
+ },
}
handler, err := op.NewOpenIDProvider(ctx, issuer, config, storage,
//we must explicitly allow the use of the http issuer
diff --git a/example/server/exampleop/templates.go b/example/server/exampleop/templates.go
new file mode 100644
index 0000000..5b5c966
--- /dev/null
+++ b/example/server/exampleop/templates.go
@@ -0,0 +1,26 @@
+package exampleop
+
+import (
+ "embed"
+ "html/template"
+
+ "github.com/sirupsen/logrus"
+)
+
+var (
+ //go:embed templates
+ templateFS embed.FS
+ templates = template.Must(template.ParseFS(templateFS, "templates/*.html"))
+)
+
+const (
+ queryAuthRequestID = "authRequestID"
+)
+
+func errMsg(err error) string {
+ if err == nil {
+ return ""
+ }
+ logrus.Error(err)
+ return err.Error()
+}
diff --git a/example/server/exampleop/templates/confirm_device.html b/example/server/exampleop/templates/confirm_device.html
new file mode 100644
index 0000000..a6bcdad
--- /dev/null
+++ b/example/server/exampleop/templates/confirm_device.html
@@ -0,0 +1,25 @@
+{{ define "confirm_device" -}}
+
+
+
+
+ Confirm device authorization
+
+
+
+ Welcome back {{.Username}}!
+
+ You are about to grant device {{.ClientID}} access to the following scopes: {{.Scopes}}.
+
+
+
+
+
+{{- end }}
diff --git a/example/server/exampleop/templates/device_login.html b/example/server/exampleop/templates/device_login.html
new file mode 100644
index 0000000..cc5b00b
--- /dev/null
+++ b/example/server/exampleop/templates/device_login.html
@@ -0,0 +1,29 @@
+{{ define "device_login" -}}
+
+
+
+
+ Login
+
+
+
+
+
+{{- end }}
diff --git a/example/server/exampleop/templates/login.html b/example/server/exampleop/templates/login.html
new file mode 100644
index 0000000..b048211
--- /dev/null
+++ b/example/server/exampleop/templates/login.html
@@ -0,0 +1,29 @@
+{{ define "login" -}}
+
+
+
+
+ Login
+
+
+
+
+`
+{{- end }}
\ No newline at end of file
diff --git a/example/server/exampleop/templates/usercode.html b/example/server/exampleop/templates/usercode.html
new file mode 100644
index 0000000..fb8fa7f
--- /dev/null
+++ b/example/server/exampleop/templates/usercode.html
@@ -0,0 +1,21 @@
+{{ define "usercode" -}}
+
+
+
+
+ Device authorization
+
+
+
+
+
+{{- end }}
diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go
index 662132c..b49ce1b 100644
--- a/example/server/storage/storage.go
+++ b/example/server/storage/storage.go
@@ -44,6 +44,8 @@ type Storage struct {
services map[string]Service
refreshTokens map[string]*RefreshToken
signingKey signingKey
+ deviceCodes map[string]deviceAuthorizationEntry
+ userCodes map[string]string
}
type signingKey struct {
@@ -105,6 +107,8 @@ func NewStorage(userStore UserStore) *Storage {
algorithm: jose.RS256,
key: key,
},
+ deviceCodes: make(map[string]deviceAuthorizationEntry),
+ userCodes: make(map[string]string),
}
}
@@ -135,6 +139,17 @@ func (s *Storage) CheckUsernamePassword(username, password, id string) error {
return fmt.Errorf("username or password wrong")
}
+func (s *Storage) CheckUsernamePasswordSimple(username, password string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ user := s.userStore.GetUserByUsername(username)
+ if user != nil && user.Password == password {
+ return nil
+ }
+ return fmt.Errorf("username or password wrong")
+}
+
// CreateAuthRequest implements the op.Storage interface
// it will be called after parsing and validation of the authentication request
func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
@@ -735,3 +750,85 @@ func appendClaim(claims map[string]interface{}, claim string, value interface{})
claims[claim] = value
return claims
}
+
+type deviceAuthorizationEntry struct {
+ deviceCode string
+ userCode string
+ state *op.DeviceAuthorizationState
+}
+
+func (s *Storage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if _, ok := s.clients[clientID]; !ok {
+ return errors.New("client not found")
+ }
+
+ if _, ok := s.userCodes[userCode]; ok {
+ return op.ErrDuplicateUserCode
+ }
+
+ s.deviceCodes[deviceCode] = deviceAuthorizationEntry{
+ deviceCode: deviceCode,
+ userCode: userCode,
+ state: &op.DeviceAuthorizationState{
+ ClientID: clientID,
+ Scopes: scopes,
+ Expires: expires,
+ },
+ }
+
+ s.userCodes[userCode] = deviceCode
+ return nil
+}
+
+func (s *Storage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*op.DeviceAuthorizationState, error) {
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ entry, ok := s.deviceCodes[deviceCode]
+ if !ok || entry.state.ClientID != clientID {
+ return nil, errors.New("device code not found for client") // is there a standard not found error in the framework?
+ }
+
+ return entry.state, nil
+}
+
+func (s *Storage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ entry, ok := s.deviceCodes[s.userCodes[userCode]]
+ if !ok {
+ return nil, errors.New("user code not found")
+ }
+
+ return entry.state, nil
+}
+
+func (s *Storage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ entry, ok := s.deviceCodes[s.userCodes[userCode]]
+ if !ok {
+ return errors.New("user code not found")
+ }
+
+ entry.state.Subject = subject
+ entry.state.Done = true
+ return nil
+}
+
+func (s *Storage) DenyDeviceAuthorization(ctx context.Context, userCode string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ s.deviceCodes[s.userCodes[userCode]].state.Denied = true
+ return nil
+}
diff --git a/pkg/client/client.go b/pkg/client/client.go
index 077baf2..b9ae008 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -1,6 +1,8 @@
package client
import (
+ "context"
+ "encoding/json"
"errors"
"fmt"
"io"
@@ -186,3 +188,94 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti
IssuedAt: oidc.Time(iat),
}, signer)
}
+
+type DeviceAuthorizationCaller interface {
+ GetDeviceAuthorizationEndpoint() string
+ HttpClient() *http.Client
+}
+
+func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
+ req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
+ if err != nil {
+ return nil, err
+ }
+ if request.ClientSecret != "" {
+ req.SetBasicAuth(request.ClientID, request.ClientSecret)
+ }
+
+ resp := new(oidc.DeviceAuthorizationResponse)
+ if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
+ return nil, err
+ }
+ return resp, nil
+}
+
+type DeviceAccessTokenRequest struct {
+ *oidc.ClientCredentialsRequest
+ oidc.DeviceAccessTokenRequest
+}
+
+func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
+ req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil)
+ if err != nil {
+ return nil, err
+ }
+ if request.ClientSecret != "" {
+ req.SetBasicAuth(request.ClientID, request.ClientSecret)
+ }
+
+ httpResp, err := caller.HttpClient().Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer httpResp.Body.Close()
+
+ resp := new(struct {
+ *oidc.AccessTokenResponse
+ *oidc.Error
+ })
+ if err = json.NewDecoder(httpResp.Body).Decode(resp); err != nil {
+ return nil, err
+ }
+
+ if httpResp.StatusCode == http.StatusOK {
+ return resp.AccessTokenResponse, nil
+ }
+
+ return nil, resp.Error
+}
+
+func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
+ for {
+ timer := time.After(interval)
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-timer:
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, interval)
+ defer cancel()
+
+ resp, err := CallDeviceAccessTokenEndpoint(ctx, request, caller)
+ if err == nil {
+ return resp, nil
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ interval += 5 * time.Second
+ }
+ var target *oidc.Error
+ if !errors.As(err, &target) {
+ return nil, err
+ }
+ switch target.ErrorType {
+ case oidc.AuthorizationPending:
+ continue
+ case oidc.SlowDown:
+ interval += 5 * time.Second
+ continue
+ default:
+ return nil, err
+ }
+ }
+}
diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go
new file mode 100644
index 0000000..73b67ca
--- /dev/null
+++ b/pkg/client/rp/device.go
@@ -0,0 +1,62 @@
+package rp
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/zitadel/oidc/v2/pkg/client"
+ "github.com/zitadel/oidc/v2/pkg/oidc"
+)
+
+func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
+ confg := rp.OAuthConfig()
+ req := &oidc.ClientCredentialsRequest{
+ GrantType: oidc.GrantTypeDeviceCode,
+ Scope: scopes,
+ ClientID: confg.ClientID,
+ ClientSecret: confg.ClientSecret,
+ }
+
+ if signer := rp.Signer(); signer != nil {
+ assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer)
+ if err != nil {
+ return nil, fmt.Errorf("failed to build assertion: %w", err)
+ }
+ req.ClientAssertion = assertion
+ req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion
+ }
+
+ return req, nil
+}
+
+// DeviceAuthorization starts a new Device Authorization flow as defined
+// in RFC 8628, section 3.1 and 3.2:
+// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
+func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
+ req, err := newDeviceClientCredentialsRequest(scopes, rp)
+ if err != nil {
+ return nil, err
+ }
+
+ return client.CallDeviceAuthorizationEndpoint(req, rp)
+}
+
+// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
+// by means of polling as defined in RFC, section 3.3 and 3.4:
+// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
+func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
+ req := &client.DeviceAccessTokenRequest{
+ DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
+ GrantType: oidc.GrantTypeDeviceCode,
+ DeviceCode: deviceCode,
+ },
+ }
+
+ req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp)
+ if err != nil {
+ return nil, err
+ }
+
+ return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp})
+}
diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go
index d2e3cf7..96fe219 100644
--- a/pkg/client/rp/relying_party.go
+++ b/pkg/client/rp/relying_party.go
@@ -59,6 +59,10 @@ type RelyingParty interface {
// UserinfoEndpoint returns the userinfo
UserinfoEndpoint() string
+ // GetDeviceAuthorizationEndpoint returns the enpoint which can
+ // be used to start a DeviceAuthorization flow.
+ GetDeviceAuthorizationEndpoint() string
+
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
IDTokenVerifier() IDTokenVerifier
// ErrorHandler returns the handler used for callback errors
@@ -121,6 +125,10 @@ func (rp *relyingParty) UserinfoEndpoint() string {
return rp.endpoints.UserinfoURL
}
+func (rp *relyingParty) GetDeviceAuthorizationEndpoint() string {
+ return rp.endpoints.DeviceAuthorizationURL
+}
+
func (rp *relyingParty) GetEndSessionEndpoint() string {
return rp.endpoints.EndSessionURL
}
@@ -495,11 +503,12 @@ type OptionFunc func(RelyingParty)
type Endpoints struct {
oauth2.Endpoint
- IntrospectURL string
- UserinfoURL string
- JKWsURL string
- EndSessionURL string
- RevokeURL string
+ IntrospectURL string
+ UserinfoURL string
+ JKWsURL string
+ EndSessionURL string
+ RevokeURL string
+ DeviceAuthorizationURL string
}
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
@@ -509,11 +518,12 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
AuthStyle: oauth2.AuthStyleAutoDetect,
TokenURL: discoveryConfig.TokenEndpoint,
},
- IntrospectURL: discoveryConfig.IntrospectionEndpoint,
- UserinfoURL: discoveryConfig.UserinfoEndpoint,
- JKWsURL: discoveryConfig.JwksURI,
- EndSessionURL: discoveryConfig.EndSessionEndpoint,
- RevokeURL: discoveryConfig.RevocationEndpoint,
+ IntrospectURL: discoveryConfig.IntrospectionEndpoint,
+ UserinfoURL: discoveryConfig.UserinfoEndpoint,
+ JKWsURL: discoveryConfig.JwksURI,
+ EndSessionURL: discoveryConfig.EndSessionEndpoint,
+ RevokeURL: discoveryConfig.RevocationEndpoint,
+ DeviceAuthorizationURL: discoveryConfig.DeviceAuthorizationEndpoint,
}
}
diff --git a/pkg/oidc/device_authorization.go b/pkg/oidc/device_authorization.go
new file mode 100644
index 0000000..68b8efa
--- /dev/null
+++ b/pkg/oidc/device_authorization.go
@@ -0,0 +1,29 @@
+package oidc
+
+// DeviceAuthorizationRequest implements
+// https://www.rfc-editor.org/rfc/rfc8628#section-3.1,
+// 3.1 Device Authorization Request.
+type DeviceAuthorizationRequest struct {
+ Scopes SpaceDelimitedArray `schema:"scope"`
+ ClientID string `schema:"client_id"`
+}
+
+// DeviceAuthorizationResponse implements
+// https://www.rfc-editor.org/rfc/rfc8628#section-3.2
+// 3.2. Device Authorization Response.
+type DeviceAuthorizationResponse struct {
+ DeviceCode string `json:"device_code"`
+ UserCode string `json:"user_code"`
+ VerificationURI string `json:"verification_uri"`
+ VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
+ ExpiresIn int `json:"expires_in"`
+ Interval int `json:"interval,omitempty"`
+}
+
+// DeviceAccessTokenRequest implements
+// https://www.rfc-editor.org/rfc/rfc8628#section-3.4,
+// Device Access Token Request.
+type DeviceAccessTokenRequest struct {
+ GrantType GrantType `json:"grant_type" schema:"grant_type"`
+ DeviceCode string `json:"device_code" schema:"device_code"`
+}
diff --git a/pkg/oidc/discovery.go b/pkg/oidc/discovery.go
index fbc417b..3574101 100644
--- a/pkg/oidc/discovery.go
+++ b/pkg/oidc/discovery.go
@@ -30,6 +30,8 @@ type DiscoveryConfiguration struct {
// EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP.
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
+ DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"`
+
// CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client.
CheckSessionIframe string `json:"check_session_iframe,omitempty"`
diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go
index 5797a59..79acecd 100644
--- a/pkg/oidc/error.go
+++ b/pkg/oidc/error.go
@@ -18,6 +18,14 @@ const (
InteractionRequired errorType = "interaction_required"
LoginRequired errorType = "login_required"
RequestNotSupported errorType = "request_not_supported"
+
+ // Additional error codes as defined in
+ // https://www.rfc-editor.org/rfc/rfc8628#section-3.5
+ // Device Access Token Response
+ AuthorizationPending errorType = "authorization_pending"
+ SlowDown errorType = "slow_down"
+ AccessDenied errorType = "access_denied"
+ ExpiredToken errorType = "expired_token"
)
var (
@@ -77,6 +85,32 @@ var (
ErrorType: RequestNotSupported,
}
}
+
+ // Device Access Token errors:
+ ErrAuthorizationPending = func() *Error {
+ return &Error{
+ ErrorType: AuthorizationPending,
+ Description: "The client SHOULD repeat the access token request to the token endpoint, after interval from device authorization response.",
+ }
+ }
+ ErrSlowDown = func() *Error {
+ return &Error{
+ ErrorType: SlowDown,
+ Description: "Polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests.",
+ }
+ }
+ ErrAccessDenied = func() *Error {
+ return &Error{
+ ErrorType: AccessDenied,
+ Description: "The authorization request was denied.",
+ }
+ }
+ ErrExpiredDeviceCode = func() *Error {
+ return &Error{
+ ErrorType: ExpiredToken,
+ Description: "The \"device_code\" has expired.",
+ }
+ }
)
type Error struct {
diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go
index 6d8f186..78bd658 100644
--- a/pkg/oidc/token_request.go
+++ b/pkg/oidc/token_request.go
@@ -27,6 +27,9 @@ const (
// GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code
GrantTypeImplicit GrantType = "implicit"
+ // GrantTypeDeviceCode
+ GrantTypeDeviceCode GrantType = "urn:ietf:params:oauth:grant-type:device_code"
+
// ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`
// used for the OAuth JWT Profile Client Authentication
ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
@@ -35,7 +38,7 @@ const (
var AllGrantTypes = []GrantType{
GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials,
GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit,
- ClientAssertionTypeJWTAssertion,
+ GrantTypeDeviceCode, ClientAssertionTypeJWTAssertion,
}
type GrantType string
diff --git a/pkg/op/client.go b/pkg/op/client.go
index e8a3347..1f5e1c9 100644
--- a/pkg/op/client.go
+++ b/pkg/op/client.go
@@ -1,8 +1,13 @@
package op
import (
+ "context"
+ "errors"
+ "net/http"
+ "net/url"
"time"
+ httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
@@ -57,3 +62,95 @@ func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseT
func IsConfidentialType(c Client) bool {
return c.ApplicationType() == ApplicationTypeWeb
}
+
+var (
+ ErrInvalidAuthHeader = errors.New("invalid basic auth header")
+ ErrNoClientCredentials = errors.New("no client credentials provided")
+ ErrMissingClientID = errors.New("client_id missing from request")
+)
+
+type ClientJWTProfile interface {
+ JWTProfileVerifier(context.Context) JWTProfileVerifier
+}
+
+func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
+ if ca.ClientAssertion == "" {
+ return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
+ }
+
+ profile, err := VerifyJWTAssertion(ctx, ca.ClientAssertion, verifier.JWTProfileVerifier(ctx))
+ if err != nil {
+ return "", oidc.ErrUnauthorizedClient().WithParent(err).WithDescription("JWT assertion failed")
+ }
+ return profile.Issuer, nil
+}
+
+func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) {
+ clientID, clientSecret, ok := r.BasicAuth()
+ if !ok {
+ return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
+ }
+ clientID, err = url.QueryUnescape(clientID)
+ if err != nil {
+ return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
+ }
+ clientSecret, err = url.QueryUnescape(clientSecret)
+ if err != nil {
+ return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
+ }
+ if err := storage.AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
+ return "", oidc.ErrUnauthorizedClient().WithParent(err)
+ }
+ return clientID, nil
+}
+
+type ClientProvider interface {
+ Decoder() httphelper.Decoder
+ Storage() Storage
+}
+
+type clientData struct {
+ ClientID string `schema:"client_id"`
+ oidc.ClientAssertionParams
+}
+
+// ClientIDFromRequest parses the request form and tries to obtain the client ID
+// and reports if it is authenticated, using a JWT or static client secrets over
+// http basic auth.
+//
+// If the Provider implements IntrospectorJWTProfile and "client_assertion" is
+// present in the form data, JWT assertion will be verified and the
+// client ID is taken from there.
+// If any of them is absent, basic auth is attempted.
+// In absence of basic auth data, the unauthenticated client id from the form
+// data is returned.
+//
+// If no client id can be obtained by any method, oidc.ErrInvalidClient
+// is returned with ErrMissingClientID wrapped in it.
+func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, authenticated bool, err error) {
+ err = r.ParseForm()
+ if err != nil {
+ return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
+ }
+
+ data := new(clientData)
+ if err = p.Decoder().Decode(data, r.PostForm); err != nil {
+ return "", false, err
+ }
+
+ JWTProfile, ok := p.(ClientJWTProfile)
+ if ok {
+ clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
+ }
+ if !ok || errors.Is(err, ErrNoClientCredentials) {
+ clientID, err = ClientBasicAuth(r, p.Storage())
+ }
+ if err == nil {
+ return clientID, true, nil
+ }
+
+ if data.ClientID == "" {
+ return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
+ }
+ return data.ClientID, false, nil
+}
diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go
new file mode 100644
index 0000000..1af4157
--- /dev/null
+++ b/pkg/op/client_test.go
@@ -0,0 +1,253 @@
+package op_test
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/golang/mock/gomock"
+ "github.com/gorilla/schema"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ httphelper "github.com/zitadel/oidc/v2/pkg/http"
+ "github.com/zitadel/oidc/v2/pkg/oidc"
+ "github.com/zitadel/oidc/v2/pkg/op"
+ "github.com/zitadel/oidc/v2/pkg/op/mock"
+)
+
+type testClientJWTProfile struct{}
+
+func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
+
+func TestClientJWTAuth(t *testing.T) {
+ type args struct {
+ ctx context.Context
+ ca oidc.ClientAssertionParams
+ verifier op.ClientJWTProfile
+ }
+ tests := []struct {
+ name string
+ args args
+ wantClientID string
+ wantErr error
+ }{
+ {
+ name: "empty assertion",
+ args: args{
+ context.Background(),
+ oidc.ClientAssertionParams{},
+ testClientJWTProfile{},
+ },
+ wantErr: op.ErrNoClientCredentials,
+ },
+ {
+ name: "verification error",
+ args: args{
+ context.Background(),
+ oidc.ClientAssertionParams{
+ ClientAssertion: "foo",
+ },
+ testClientJWTProfile{},
+ },
+ wantErr: oidc.ErrParse,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotClientID, err := op.ClientJWTAuth(tt.args.ctx, tt.args.ca, tt.args.verifier)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.wantClientID, gotClientID)
+ })
+ }
+}
+
+func TestClientBasicAuth(t *testing.T) {
+ errWrong := errors.New("wrong secret")
+
+ type args struct {
+ username string
+ password string
+ }
+ tests := []struct {
+ name string
+ args *args
+ storage op.Storage
+ wantClientID string
+ wantErr error
+ }{
+ {
+ name: "no args",
+ wantErr: op.ErrNoClientCredentials,
+ },
+ {
+ name: "username unescape err",
+ args: &args{
+ username: "%",
+ password: "bar",
+ },
+ wantErr: op.ErrInvalidAuthHeader,
+ },
+ {
+ name: "password unescape err",
+ args: &args{
+ username: "foo",
+ password: "%",
+ },
+ wantErr: op.ErrInvalidAuthHeader,
+ },
+ {
+ name: "auth error",
+ args: &args{
+ username: "foo",
+ password: "wrong",
+ },
+ storage: func() op.Storage {
+ s := mock.NewMockStorage(gomock.NewController(t))
+ s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong)
+ return s
+ }(),
+ wantErr: errWrong,
+ },
+ {
+ name: "auth error",
+ args: &args{
+ username: "foo",
+ password: "bar",
+ },
+ storage: func() op.Storage {
+ s := mock.NewMockStorage(gomock.NewController(t))
+ s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
+ return s
+ }(),
+ wantClientID: "foo",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := httptest.NewRequest(http.MethodGet, "/foo", nil)
+ if tt.args != nil {
+ r.SetBasicAuth(tt.args.username, tt.args.password)
+ }
+
+ gotClientID, err := op.ClientBasicAuth(r, tt.storage)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.wantClientID, gotClientID)
+ })
+ }
+}
+
+type errReader struct{}
+
+func (errReader) Read([]byte) (int, error) {
+ return 0, io.ErrNoProgress
+}
+
+type testClientProvider struct {
+ storage op.Storage
+}
+
+func (testClientProvider) Decoder() httphelper.Decoder {
+ return schema.NewDecoder()
+}
+
+func (p testClientProvider) Storage() op.Storage {
+ return p.storage
+}
+
+func TestClientIDFromRequest(t *testing.T) {
+ type args struct {
+ body io.Reader
+ p op.ClientProvider
+ }
+ type basicAuth struct {
+ username string
+ password string
+ }
+ tests := []struct {
+ name string
+ args args
+ basicAuth *basicAuth
+ wantClientID string
+ wantAuthenticated bool
+ wantErr bool
+ }{
+ {
+ name: "parse error",
+ args: args{
+ body: errReader{},
+ },
+ wantErr: true,
+ },
+ {
+ name: "unauthenticated",
+ args: args{
+ body: strings.NewReader(
+ url.Values{
+ "client_id": []string{"foo"},
+ }.Encode(),
+ ),
+ p: testClientProvider{
+ storage: mock.NewStorage(t),
+ },
+ },
+ wantClientID: "foo",
+ wantAuthenticated: false,
+ },
+ {
+ name: "authenticated",
+ args: args{
+ body: strings.NewReader(
+ url.Values{}.Encode(),
+ ),
+ p: testClientProvider{
+ storage: func() op.Storage {
+ s := mock.NewMockStorage(gomock.NewController(t))
+ s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
+ return s
+ }(),
+ },
+ },
+ basicAuth: &basicAuth{
+ username: "foo",
+ password: "bar",
+ },
+ wantClientID: "foo",
+ wantAuthenticated: true,
+ },
+ {
+ name: "missing client id",
+ args: args{
+ body: strings.NewReader(
+ url.Values{}.Encode(),
+ ),
+ p: testClientProvider{
+ storage: mock.NewStorage(t),
+ },
+ },
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
+ r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ if tt.basicAuth != nil {
+ r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
+ }
+
+ gotClientID, gotAuthenticated, err := op.ClientIDFromRequest(r, tt.args.p)
+ if tt.wantErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+ assert.Equal(t, tt.wantClientID, gotClientID)
+ assert.Equal(t, tt.wantAuthenticated, gotAuthenticated)
+ })
+ }
+}
diff --git a/pkg/op/config.go b/pkg/op/config.go
index c40fa2d..c40ed39 100644
--- a/pkg/op/config.go
+++ b/pkg/op/config.go
@@ -27,6 +27,7 @@ type Configuration interface {
RevocationEndpoint() Endpoint
EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint
+ DeviceAuthorizationEndpoint() Endpoint
AuthMethodPostSupported() bool
CodeMethodS256Supported() bool
@@ -36,6 +37,7 @@ type Configuration interface {
GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool
+ GrantTypeDeviceCodeSupported() bool
IntrospectionAuthMethodPrivateKeyJWTSupported() bool
IntrospectionEndpointSigningAlgorithmsSupported() []string
RevocationAuthMethodPrivateKeyJWTSupported() bool
@@ -44,6 +46,7 @@ type Configuration interface {
RequestObjectSigningAlgorithmsSupported() []string
SupportedUILocales() []language.Tag
+ DeviceAuthorization() DeviceAuthorizationConfig
}
type IssuerFromRequest func(r *http.Request) string
diff --git a/pkg/op/device.go b/pkg/op/device.go
new file mode 100644
index 0000000..04c06f2
--- /dev/null
+++ b/pkg/op/device.go
@@ -0,0 +1,265 @@
+package op
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "math/big"
+ "net/http"
+ "strings"
+ "time"
+
+ httphelper "github.com/zitadel/oidc/v2/pkg/http"
+ "github.com/zitadel/oidc/v2/pkg/oidc"
+)
+
+type DeviceAuthorizationConfig struct {
+ Lifetime time.Duration
+ PollInterval time.Duration
+ UserFormURL string // the URL where the user must go to authorize the device
+ UserCode UserCodeConfig
+}
+
+type UserCodeConfig struct {
+ CharSet string
+ CharAmount int
+ DashInterval int
+}
+
+const (
+ CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ"
+ CharSetDigits = "0123456789"
+)
+
+var (
+ UserCodeBase20 = UserCodeConfig{
+ CharSet: CharSetBase20,
+ CharAmount: 8,
+ DashInterval: 4,
+ }
+ UserCodeDigits = UserCodeConfig{
+ CharSet: CharSetDigits,
+ CharAmount: 9,
+ DashInterval: 3,
+ }
+)
+
+func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if err := DeviceAuthorization(w, r, o); err != nil {
+ RequestError(w, r, err)
+ }
+ }
+}
+
+func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
+ storage, err := assertDeviceStorage(o.Storage())
+ if err != nil {
+ return err
+ }
+
+ req, err := ParseDeviceCodeRequest(r, o)
+ if err != nil {
+ return err
+ }
+
+ config := o.DeviceAuthorization()
+
+ deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
+ if err != nil {
+ return err
+ }
+ userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
+ if err != nil {
+ return err
+ }
+
+ expires := time.Now().Add(config.Lifetime)
+ err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
+ if err != nil {
+ return err
+ }
+
+ response := &oidc.DeviceAuthorizationResponse{
+ DeviceCode: deviceCode,
+ UserCode: userCode,
+ VerificationURI: config.UserFormURL,
+ ExpiresIn: int(config.Lifetime / time.Second),
+ Interval: int(config.PollInterval / time.Second),
+ }
+
+ response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode)
+
+ httphelper.MarshalJSON(w, response)
+ return nil
+}
+
+func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
+ clientID, _, err := ClientIDFromRequest(r, o)
+ if err != nil {
+ return nil, err
+ }
+
+ req := new(oidc.DeviceAuthorizationRequest)
+ if err := o.Decoder().Decode(req, r.Form); err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
+ }
+ req.ClientID = clientID
+
+ return req, nil
+}
+
+// 16 bytes gives 128 bit of entropy.
+// results in a 22 character base64 encoded string.
+const RecommendedDeviceCodeBytes = 16
+
+func NewDeviceCode(nBytes int) (string, error) {
+ bytes := make([]byte, nBytes)
+ if _, err := rand.Read(bytes); err != nil {
+ return "", fmt.Errorf("%w getting entropy for device code", err)
+ }
+ return base64.RawURLEncoding.EncodeToString(bytes), nil
+}
+
+func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
+ var buf strings.Builder
+ if dashInterval > 0 {
+ buf.Grow(charAmount + charAmount/dashInterval - 1)
+ } else {
+ buf.Grow(charAmount)
+ }
+
+ max := big.NewInt(int64(len(charSet)))
+
+ for i := 0; i < charAmount; i++ {
+ if dashInterval != 0 && i != 0 && i%dashInterval == 0 {
+ buf.WriteByte('-')
+ }
+
+ bi, err := rand.Int(rand.Reader, max)
+ if err != nil {
+ return "", fmt.Errorf("%w getting entropy for user code", err)
+ }
+
+ buf.WriteRune(charSet[int(bi.Int64())])
+ }
+
+ return buf.String(), nil
+}
+
+type deviceAccessTokenRequest struct {
+ subject string
+ audience []string
+ scopes []string
+}
+
+func (r *deviceAccessTokenRequest) GetSubject() string {
+ return r.subject
+}
+
+func (r *deviceAccessTokenRequest) GetAudience() []string {
+ return r.audience
+}
+
+func (r *deviceAccessTokenRequest) GetScopes() []string {
+ return r.scopes
+}
+
+func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
+ if err := deviceAccessToken(w, r, exchanger); err != nil {
+ RequestError(w, r, err)
+ }
+}
+
+func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error {
+ // use a limited context timeout shorter as the default
+ // poll interval of 5 seconds.
+ ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
+ defer cancel()
+ r = r.WithContext(ctx)
+
+ clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger)
+ if err != nil {
+ return err
+ }
+
+ req, err := ParseDeviceAccessTokenRequest(r, exchanger)
+ if err != nil {
+ return err
+ }
+ state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
+ if err != nil {
+ return err
+ }
+
+ client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
+ if err != nil {
+ return err
+ }
+ if clientAuthenticated != IsConfidentialType(client) {
+ return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
+ WithDescription("confidential client requires authentication")
+ }
+
+ tokenRequest := &deviceAccessTokenRequest{
+ subject: state.Subject,
+ audience: []string{clientID},
+ scopes: state.Scopes,
+ }
+ resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
+ if err != nil {
+ return err
+ }
+
+ httphelper.MarshalJSON(w, resp)
+ return nil
+}
+
+func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
+ req := new(oidc.DeviceAccessTokenRequest)
+ if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
+ return nil, err
+ }
+ return req, nil
+}
+
+func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
+ storage, err := assertDeviceStorage(exchanger.Storage())
+ if err != nil {
+ return nil, err
+ }
+
+ state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
+ if errors.Is(err, context.DeadlineExceeded) {
+ return nil, oidc.ErrSlowDown().WithParent(err)
+ }
+ if err != nil {
+ return nil, oidc.ErrAccessDenied().WithParent(err)
+ }
+ if state.Denied {
+ return state, oidc.ErrAccessDenied()
+ }
+ if state.Done {
+ return state, nil
+ }
+ if time.Now().After(state.Expires) {
+ return state, oidc.ErrExpiredDeviceCode()
+ }
+ return state, oidc.ErrAuthorizationPending()
+}
+
+func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) {
+ accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, client, "")
+ if err != nil {
+ return nil, err
+ }
+
+ return &oidc.AccessTokenResponse{
+ AccessToken: accessToken,
+ RefreshToken: refreshToken,
+ TokenType: oidc.BearerToken,
+ ExpiresIn: uint64(validity.Seconds()),
+ }, nil
+}
diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go
new file mode 100644
index 0000000..ca68759
--- /dev/null
+++ b/pkg/op/device_test.go
@@ -0,0 +1,453 @@
+package op_test
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/base64"
+ "io"
+ mr "math/rand"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/zitadel/oidc/v2/example/server/storage"
+ "github.com/zitadel/oidc/v2/pkg/oidc"
+ "github.com/zitadel/oidc/v2/pkg/op"
+ "golang.org/x/text/language"
+)
+
+var testProvider op.OpenIDProvider
+
+const (
+ testIssuer = "https://localhost:9998/"
+ pathLoggedOut = "/logged-out"
+)
+
+func init() {
+ config := &op.Config{
+ CryptoKey: sha256.Sum256([]byte("test")),
+ DefaultLogoutRedirectURI: pathLoggedOut,
+ CodeMethodS256: true,
+ AuthMethodPost: true,
+ AuthMethodPrivateKeyJWT: true,
+ GrantTypeRefreshToken: true,
+ RequestObjectSupported: true,
+ SupportedUILocales: []language.Tag{language.English},
+ DeviceAuthorization: op.DeviceAuthorizationConfig{
+ Lifetime: 5 * time.Minute,
+ PollInterval: 5 * time.Second,
+ UserFormURL: testIssuer + "device",
+ UserCode: op.UserCodeBase20,
+ },
+ }
+
+ storage.RegisterClients(
+ storage.NativeClient("native"),
+ storage.WebClient("web", "secret"),
+ storage.WebClient("api", "secret"),
+ )
+
+ var err error
+ testProvider, err = op.NewOpenIDProvider(context.TODO(), testIssuer, config,
+ storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
+ )
+ if err != nil {
+ panic(err)
+ }
+}
+
+func Test_deviceAuthorizationHandler(t *testing.T) {
+ req := &oidc.DeviceAuthorizationRequest{
+ Scopes: []string{"foo", "bar"},
+ ClientID: "web",
+ }
+ values := make(url.Values)
+ testProvider.Encoder().Encode(req, values)
+ body := strings.NewReader(values.Encode())
+
+ r := httptest.NewRequest(http.MethodPost, "/", body)
+ r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ w := httptest.NewRecorder()
+
+ runWithRandReader(mr.New(mr.NewSource(1)), func() {
+ op.DeviceAuthorizationHandler(testProvider)(w, r)
+ })
+
+ result := w.Result()
+
+ assert.Less(t, result.StatusCode, 300)
+
+ got, _ := io.ReadAll(result.Body)
+ assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
+}
+
+func TestParseDeviceCodeRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ req *oidc.DeviceAuthorizationRequest
+ wantErr bool
+ }{
+ {
+ name: "empty request",
+ wantErr: true,
+ },
+ /* decoding a SpaceDelimitedArray is broken
+ https://github.com/zitadel/oidc/issues/295
+ {
+ name: "success",
+ req: &oidc.DeviceAuthorizationRequest{
+ Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
+ ClientID: "web",
+ },
+ },
+ */
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var body io.Reader
+ if tt.req != nil {
+ values := make(url.Values)
+ testProvider.Encoder().Encode(tt.req, values)
+ body = strings.NewReader(values.Encode())
+ }
+
+ r := httptest.NewRequest(http.MethodPost, "/", body)
+ r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ got, err := op.ParseDeviceCodeRequest(r, testProvider)
+ if tt.wantErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+ assert.Equal(t, tt.req, got)
+ })
+ }
+}
+
+func runWithRandReader(r io.Reader, f func()) {
+ originalReader := rand.Reader
+ rand.Reader = r
+ defer func() {
+ rand.Reader = originalReader
+ }()
+
+ f()
+}
+
+func TestNewDeviceCode(t *testing.T) {
+ t.Run("reader error", func(t *testing.T) {
+ runWithRandReader(errReader{}, func() {
+ _, err := op.NewDeviceCode(16)
+ require.Error(t, err)
+ })
+ })
+
+ t.Run("different lengths, rand reader", func(t *testing.T) {
+ for i := 1; i <= 32; i++ {
+ got, err := op.NewDeviceCode(i)
+ require.NoError(t, err)
+ assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i))
+ }
+ })
+
+}
+
+func TestNewUserCode(t *testing.T) {
+ type args struct {
+ charset []rune
+ charAmount int
+ dashInterval int
+ }
+ tests := []struct {
+ name string
+ args args
+ reader io.Reader
+ want string
+ wantErr bool
+ }{
+ {
+ name: "reader error",
+ args: args{
+ charset: []rune(op.CharSetBase20),
+ charAmount: 8,
+ dashInterval: 4,
+ },
+ reader: errReader{},
+ wantErr: true,
+ },
+ {
+ name: "base20",
+ args: args{
+ charset: []rune(op.CharSetBase20),
+ charAmount: 8,
+ dashInterval: 4,
+ },
+ reader: mr.New(mr.NewSource(1)),
+ want: "XKCD-HTTD",
+ },
+ {
+ name: "digits",
+ args: args{
+ charset: []rune(op.CharSetDigits),
+ charAmount: 9,
+ dashInterval: 3,
+ },
+ reader: mr.New(mr.NewSource(1)),
+ want: "271-256-225",
+ },
+ {
+ name: "no dashes",
+ args: args{
+ charset: []rune(op.CharSetDigits),
+ charAmount: 9,
+ },
+ reader: mr.New(mr.NewSource(1)),
+ want: "271256225",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ runWithRandReader(tt.reader, func() {
+ got, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
+ if tt.wantErr {
+ require.ErrorIs(t, err, io.ErrNoProgress)
+ } else {
+ require.NoError(t, err)
+ }
+ assert.Equal(t, tt.want, got)
+ })
+
+ })
+ }
+
+ t.Run("crypto/rand", func(t *testing.T) {
+ const testN = 100000
+
+ for _, c := range []op.UserCodeConfig{op.UserCodeBase20, op.UserCodeDigits} {
+ t.Run(c.CharSet, func(t *testing.T) {
+ results := make(map[string]int)
+
+ for i := 0; i < testN; i++ {
+ code, err := op.NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval)
+ require.NoError(t, err)
+ results[code]++
+ }
+
+ t.Log(results)
+
+ var duplicates int
+ for code, count := range results {
+ assert.Less(t, count, 3, code)
+ if count == 2 {
+ duplicates++
+ }
+ }
+
+ })
+ }
+ })
+}
+
+func BenchmarkNewUserCode(b *testing.B) {
+ type args struct {
+ charset []rune
+ charAmount int
+ dashInterval int
+ }
+ tests := []struct {
+ name string
+ args args
+ reader io.Reader
+ }{
+ {
+ name: "math rand, base20",
+ args: args{
+ charset: []rune(op.CharSetBase20),
+ charAmount: 8,
+ dashInterval: 4,
+ },
+ reader: mr.New(mr.NewSource(1)),
+ },
+ {
+ name: "math rand, digits",
+ args: args{
+ charset: []rune(op.CharSetDigits),
+ charAmount: 9,
+ dashInterval: 3,
+ },
+ reader: mr.New(mr.NewSource(1)),
+ },
+ {
+ name: "crypto rand, base20",
+ args: args{
+ charset: []rune(op.CharSetBase20),
+ charAmount: 8,
+ dashInterval: 4,
+ },
+ reader: rand.Reader,
+ },
+ {
+ name: "crypto rand, digits",
+ args: args{
+ charset: []rune(op.CharSetDigits),
+ charAmount: 9,
+ dashInterval: 3,
+ },
+ reader: rand.Reader,
+ },
+ }
+ for _, tt := range tests {
+ runWithRandReader(tt.reader, func() {
+ b.Run(tt.name, func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ _, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
+ require.NoError(b, err)
+ }
+ })
+
+ })
+ }
+}
+
+func TestDeviceAccessToken(t *testing.T) {
+ storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
+ storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"})
+ storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim")
+
+ values := make(url.Values)
+ values.Set("client_id", "native")
+ values.Set("grant_type", string(oidc.GrantTypeDeviceCode))
+ values.Set("device_code", "qwerty")
+
+ r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(values.Encode()))
+ r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ w := httptest.NewRecorder()
+
+ op.DeviceAccessToken(w, r, testProvider)
+
+ result := w.Result()
+ got, _ := io.ReadAll(result.Body)
+ t.Log(string(got))
+ assert.Less(t, result.StatusCode, 300)
+ assert.NotEmpty(t, string(got))
+}
+
+func TestCheckDeviceAuthorizationState(t *testing.T) {
+ now := time.Now()
+
+ storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
+ storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"})
+ storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"})
+ storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"})
+ storage.StoreDeviceAuthorization(context.Background(), "native", "expired", "expired", now.Add(-time.Minute), []string{"foo"})
+
+ storage.DenyDeviceAuthorization(context.Background(), "denied")
+ storage.CompleteDeviceAuthorization(context.Background(), "completed", "tim")
+
+ exceededCtx, cancel := context.WithTimeout(context.Background(), -time.Second)
+ defer cancel()
+
+ type args struct {
+ ctx context.Context
+ clientID string
+ deviceCode string
+ }
+ tests := []struct {
+ name string
+ args args
+ want *op.DeviceAuthorizationState
+ wantErr error
+ }{
+ {
+ name: "pending",
+ args: args{
+ ctx: context.Background(),
+ clientID: "native",
+ deviceCode: "pending",
+ },
+ want: &op.DeviceAuthorizationState{
+ ClientID: "native",
+ Scopes: []string{"foo"},
+ Expires: now.Add(time.Minute),
+ },
+ wantErr: oidc.ErrAuthorizationPending(),
+ },
+ {
+ name: "slow down",
+ args: args{
+ ctx: exceededCtx,
+ clientID: "native",
+ deviceCode: "ok",
+ },
+ wantErr: oidc.ErrSlowDown(),
+ },
+ {
+ name: "wrong client",
+ args: args{
+ ctx: context.Background(),
+ clientID: "foo",
+ deviceCode: "ok",
+ },
+ wantErr: oidc.ErrAccessDenied(),
+ },
+ {
+ name: "denied",
+ args: args{
+ ctx: context.Background(),
+ clientID: "native",
+ deviceCode: "denied",
+ },
+ want: &op.DeviceAuthorizationState{
+ ClientID: "native",
+ Scopes: []string{"foo"},
+ Expires: now.Add(time.Minute),
+ Denied: true,
+ },
+ wantErr: oidc.ErrAccessDenied(),
+ },
+ {
+ name: "completed",
+ args: args{
+ ctx: context.Background(),
+ clientID: "native",
+ deviceCode: "completed",
+ },
+ want: &op.DeviceAuthorizationState{
+ ClientID: "native",
+ Scopes: []string{"foo"},
+ Expires: now.Add(time.Minute),
+ Subject: "tim",
+ Done: true,
+ },
+ },
+ {
+ name: "expired",
+ args: args{
+ ctx: context.Background(),
+ clientID: "native",
+ deviceCode: "expired",
+ },
+ want: &op.DeviceAuthorizationState{
+ ClientID: "native",
+ Scopes: []string{"foo"},
+ Expires: now.Add(-time.Minute),
+ },
+ wantErr: oidc.ErrExpiredDeviceCode(),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := op.CheckDeviceAuthorizationState(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, testProvider)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go
index 9a25afc..26f89eb 100644
--- a/pkg/op/discovery.go
+++ b/pkg/op/discovery.go
@@ -44,6 +44,7 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov
RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer),
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
JwksURI: config.KeysEndpoint().Absolute(issuer),
+ DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
@@ -92,6 +93,9 @@ func GrantTypes(c Configuration) []oidc.GrantType {
if c.GrantTypeJWTAuthorizationSupported() {
grantTypes = append(grantTypes, oidc.GrantTypeBearer)
}
+ if c.GrantTypeDeviceCodeSupported() {
+ grantTypes = append(grantTypes, oidc.GrantTypeDeviceCode)
+ }
return grantTypes
}
diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go
index e1b07dd..2d0b8af 100644
--- a/pkg/op/discovery_test.go
+++ b/pkg/op/discovery_test.go
@@ -131,6 +131,7 @@ func Test_GrantTypes(t *testing.T) {
c.EXPECT().GrantTypeTokenExchangeSupported().Return(false)
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(false)
c.EXPECT().GrantTypeClientCredentialsSupported().Return(false)
+ c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
return c
}(),
},
@@ -148,6 +149,7 @@ func Test_GrantTypes(t *testing.T) {
c.EXPECT().GrantTypeTokenExchangeSupported().Return(true)
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(true)
c.EXPECT().GrantTypeClientCredentialsSupported().Return(true)
+ c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
return c
}(),
},
diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go
index fc3158a..44b5ceb 100644
--- a/pkg/op/mock/configuration.mock.go
+++ b/pkg/op/mock/configuration.mock.go
@@ -92,6 +92,34 @@ func (mr *MockConfigurationMockRecorder) CodeMethodS256Supported() *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported))
}
+// DeviceAuthorization mocks base method.
+func (m *MockConfiguration) DeviceAuthorization() op.DeviceAuthorizationConfig {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DeviceAuthorization")
+ ret0, _ := ret[0].(op.DeviceAuthorizationConfig)
+ return ret0
+}
+
+// DeviceAuthorization indicates an expected call of DeviceAuthorization.
+func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorization", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorization))
+}
+
+// DeviceAuthorizationEndpoint mocks base method.
+func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
+ ret0, _ := ret[0].(op.Endpoint)
+ return ret0
+}
+
+// DeviceAuthorizationEndpoint indicates an expected call of DeviceAuthorizationEndpoint.
+func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorizationEndpoint))
+}
+
// EndSessionEndpoint mocks base method.
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
m.ctrl.T.Helper()
@@ -120,6 +148,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeClientCredentialsSupported() *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported))
}
+// GrantTypeDeviceCodeSupported mocks base method.
+func (m *MockConfiguration) GrantTypeDeviceCodeSupported() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GrantTypeDeviceCodeSupported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// GrantTypeDeviceCodeSupported indicates an expected call of GrantTypeDeviceCodeSupported.
+func (mr *MockConfigurationMockRecorder) GrantTypeDeviceCodeSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeDeviceCodeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeDeviceCodeSupported))
+}
+
// GrantTypeJWTAuthorizationSupported mocks base method.
func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool {
m.ctrl.T.Helper()
@@ -358,6 +400,20 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
}
+// UserCodeFormEndpoint mocks base method.
+func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "UserCodeFormEndpoint")
+ ret0, _ := ret[0].(op.Endpoint)
+ return ret0
+}
+
+// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint.
+func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint))
+}
+
// UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
m.ctrl.T.Helper()
diff --git a/pkg/op/op.go b/pkg/op/op.go
index 699fb45..2859722 100644
--- a/pkg/op/op.go
+++ b/pkg/op/op.go
@@ -27,17 +27,19 @@ const (
defaultRevocationEndpoint = "revoke"
defaultEndSessionEndpoint = "end_session"
defaultKeysEndpoint = "keys"
+ defaultDeviceAuthzEndpoint = "/device_authorization"
)
var (
DefaultEndpoints = &endpoints{
- Authorization: NewEndpoint(defaultAuthorizationEndpoint),
- Token: NewEndpoint(defaultTokenEndpoint),
- Introspection: NewEndpoint(defaultIntrospectEndpoint),
- Userinfo: NewEndpoint(defaultUserinfoEndpoint),
- Revocation: NewEndpoint(defaultRevocationEndpoint),
- EndSession: NewEndpoint(defaultEndSessionEndpoint),
- JwksURI: NewEndpoint(defaultKeysEndpoint),
+ Authorization: NewEndpoint(defaultAuthorizationEndpoint),
+ Token: NewEndpoint(defaultTokenEndpoint),
+ Introspection: NewEndpoint(defaultIntrospectEndpoint),
+ Userinfo: NewEndpoint(defaultUserinfoEndpoint),
+ Revocation: NewEndpoint(defaultRevocationEndpoint),
+ EndSession: NewEndpoint(defaultEndSessionEndpoint),
+ JwksURI: NewEndpoint(defaultKeysEndpoint),
+ DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
}
defaultCORSOptions = cors.Options{
@@ -95,6 +97,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
+ router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o))
return router
}
@@ -118,17 +121,19 @@ type Config struct {
GrantTypeRefreshToken bool
RequestObjectSupported bool
SupportedUILocales []language.Tag
+ DeviceAuthorization DeviceAuthorizationConfig
}
type endpoints struct {
- Authorization Endpoint
- Token Endpoint
- Introspection Endpoint
- Userinfo Endpoint
- Revocation Endpoint
- EndSession Endpoint
- CheckSessionIframe Endpoint
- JwksURI Endpoint
+ Authorization Endpoint
+ Token Endpoint
+ Introspection Endpoint
+ Userinfo Endpoint
+ Revocation Endpoint
+ EndSession Endpoint
+ CheckSessionIframe Endpoint
+ JwksURI Endpoint
+ DeviceAuthorization Endpoint
}
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@@ -145,6 +150,7 @@ type endpoints struct {
// /revoke
// /end_session
// /keys
+// /device_authorization
//
// This does not include login. Login is handled with a redirect that includes the
// request ID. The redirect for logins is specified per-client by Client.LoginURL().
@@ -242,6 +248,10 @@ func (o *Provider) EndSessionEndpoint() Endpoint {
return o.endpoints.EndSession
}
+func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
+ return o.endpoints.DeviceAuthorization
+}
+
func (o *Provider) KeysEndpoint() Endpoint {
return o.endpoints.JwksURI
}
@@ -275,6 +285,11 @@ func (o *Provider) GrantTypeJWTAuthorizationSupported() bool {
return true
}
+func (o *Provider) GrantTypeDeviceCodeSupported() bool {
+ _, ok := o.storage.(DeviceAuthorizationStorage)
+ return ok
+}
+
func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
return true
}
@@ -308,6 +323,10 @@ func (o *Provider) SupportedUILocales() []language.Tag {
return o.config.SupportedUILocales
}
+func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig {
+ return o.config.DeviceAuthorization
+}
+
func (o *Provider) Storage() Storage {
return o.storage
}
diff --git a/pkg/op/storage.go b/pkg/op/storage.go
index 1e19c76..ebab1c3 100644
--- a/pkg/op/storage.go
+++ b/pkg/op/storage.go
@@ -151,3 +151,50 @@ type EndSessionRequest struct {
ClientID string
RedirectURI string
}
+
+var ErrDuplicateUserCode = errors.New("user code already exists")
+
+type DeviceAuthorizationState struct {
+ ClientID string
+ Scopes []string
+ Expires time.Time
+ Done bool
+ Subject string
+ Denied bool
+}
+
+type DeviceAuthorizationStorage interface {
+ // StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
+ // User code will be used by the user to complete the login flow and must be unique.
+ // ErrDuplicateUserCode signals the caller should try again with a new code.
+ //
+ // Note that user codes are low entropy keys and when many exist in the
+ // database, the change for collisions increases. Therefore implementers
+ // of this interface must make sure that user codes of expired authentication flows are purged,
+ // after some time.
+ StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error
+
+ // GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database.
+ // The method is polled untill the the authorization is eighter Completed, Expired or Denied.
+ GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error)
+
+ // GetDeviceAuthorizationByUserCode resturn the current state of the device authorization flow,
+ // identified by the user code.
+ GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*DeviceAuthorizationState, error)
+
+ // CompleteDeviceAuthorization marks a device authorization entry as Completed,
+ // identified by userCode. The Subject is added to the state, so that
+ // GetDeviceAuthorizatonState can use it to create a new Access Token.
+ CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
+
+ // DenyDeviceAuthorization marks a device authorization entry as Denied.
+ DenyDeviceAuthorization(ctx context.Context, userCode string) error
+}
+
+func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) {
+ storage, ok := s.(DeviceAuthorizationStorage)
+ if !ok {
+ return nil, oidc.ErrUnsupportedGrantType().WithDescription("device_code grant not supported")
+ }
+ return storage, nil
+}
diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go
index dfc8954..e7ca7c4 100644
--- a/pkg/op/token_intospection.go
+++ b/pkg/op/token_intospection.go
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"net/http"
- "net/url"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
@@ -50,38 +49,19 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
}
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {
- err = r.ParseForm()
+ clientID, authenticated, err := ClientIDFromRequest(r, introspector)
if err != nil {
- return "", "", errors.New("unable to parse request")
+ return "", "", err
}
- req := new(struct {
- oidc.IntrospectionRequest
- oidc.ClientAssertionParams
- })
+ if !authenticated {
+ return "", "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
+ }
+
+ req := new(oidc.IntrospectionRequest)
err = introspector.Decoder().Decode(req, r.Form)
if err != nil {
return "", "", errors.New("unable to parse request")
}
- if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" {
- profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier(r.Context()))
- if err == nil {
- return req.Token, profile.Issuer, nil
- }
- }
- clientID, clientSecret, ok := r.BasicAuth()
- if ok {
- clientID, err = url.QueryUnescape(clientID)
- if err != nil {
- return "", "", errors.New("invalid basic auth header")
- }
- clientSecret, err = url.QueryUnescape(clientSecret)
- if err != nil {
- return "", "", errors.New("invalid basic auth header")
- }
- if err := introspector.Storage().AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
- return "", "", err
- }
- return req.Token, clientID, nil
- }
- return "", "", errors.New("invalid authorization")
+
+ return req.Token, clientID, nil
}
diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go
index 3d65ea0..b9e9805 100644
--- a/pkg/op/token_request.go
+++ b/pkg/op/token_request.go
@@ -19,6 +19,7 @@ type Exchanger interface {
GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool
+ GrantTypeDeviceCodeSupported() bool
AccessTokenVerifier(context.Context) AccessTokenVerifier
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
}
@@ -56,6 +57,11 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
ClientCredentialsExchange(w, r, exchanger)
return
}
+ case string(oidc.GrantTypeDeviceCode):
+ if exchanger.GrantTypeDeviceCodeSupported() {
+ DeviceAccessToken(w, r, exchanger)
+ return
+ }
case "":
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
return