implement RFC 8628: Device authorization grant

This commit is contained in:
Tim Möhlmann 2023-03-01 09:59:17 +02:00 committed by GitHub
parent 03f71a67c2
commit 2342f208ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 1968 additions and 97 deletions

View file

@ -0,0 +1,61 @@
package main
import (
"context"
"fmt"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v2/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
)
var (
key = []byte("test1234test1234")
)
func main() {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
defer stop()
clientID := os.Getenv("CLIENT_ID")
clientSecret := os.Getenv("CLIENT_SECRET")
keyPath := os.Getenv("KEY_PATH")
issuer := os.Getenv("ISSUER")
scopes := strings.Split(os.Getenv("SCOPES"), " ")
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
var options []rp.Option
if clientSecret == "" {
options = append(options, rp.WithPKCE(cookieHandler))
}
if keyPath != "" {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
}
provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...)
if err != nil {
logrus.Fatalf("error creating provider %s", err.Error())
}
logrus.Info("starting device authorization flow")
resp, err := rp.DeviceAuthorization(scopes, provider)
if err != nil {
logrus.Fatal(err)
}
logrus.Info("resp", resp)
fmt.Printf("\nPlease browse to %s and enter code %s\n", resp.VerificationURI, resp.UserCode)
logrus.Info("start polling")
token, err := rp.DeviceAccessToken(ctx, resp.DeviceCode, time.Duration(resp.Interval)*time.Second, provider)
if err != nil {
logrus.Fatal(err)
}
logrus.Infof("successfully obtained token: %v", token)
}

View file

@ -0,0 +1,191 @@
package exampleop
import (
"errors"
"fmt"
"io"
"net/http"
"net/url"
"github.com/gorilla/mux"
"github.com/gorilla/securecookie"
"github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v2/pkg/op"
)
type deviceAuthenticate interface {
CheckUsernamePasswordSimple(username, password string) error
op.DeviceAuthorizationStorage
}
type deviceLogin struct {
storage deviceAuthenticate
cookie *securecookie.SecureCookie
}
func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) {
l := &deviceLogin{
storage: storage,
cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil),
}
router.HandleFunc("", l.userCodeHandler)
router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler)
router.HandleFunc("/confirm", l.confirmHandler)
}
func renderUserCode(w io.Writer, err error) {
data := struct {
Error string
}{
Error: errMsg(err),
}
if err := templates.ExecuteTemplate(w, "usercode", data); err != nil {
logrus.Error(err)
}
}
func renderDeviceLogin(w http.ResponseWriter, userCode string, err error) {
data := &struct {
UserCode string
Error string
}{
UserCode: userCode,
Error: errMsg(err),
}
if err = templates.ExecuteTemplate(w, "device_login", data); err != nil {
logrus.Error(err)
}
}
func renderConfirmPage(w http.ResponseWriter, username, clientID string, scopes []string) {
data := &struct {
Username string
ClientID string
Scopes []string
}{
Username: username,
ClientID: clientID,
Scopes: scopes,
}
if err := templates.ExecuteTemplate(w, "confirm_device", data); err != nil {
logrus.Error(err)
}
}
func (d *deviceLogin) userCodeHandler(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
renderUserCode(w, err)
return
}
userCode := r.Form.Get("user_code")
if userCode == "" {
if prompt, _ := url.QueryUnescape(r.Form.Get("prompt")); prompt != "" {
err = errors.New(prompt)
}
renderUserCode(w, err)
return
}
renderDeviceLogin(w, userCode, nil)
}
func redirectBack(w http.ResponseWriter, r *http.Request, prompt string) {
values := make(url.Values)
values.Set("prompt", url.QueryEscape(prompt))
url := url.URL{
Path: "/device",
RawQuery: values.Encode(),
}
http.Redirect(w, r, url.String(), http.StatusSeeOther)
}
const userCodeCookieName = "user_code"
type userCodeCookie struct {
UserCode string
UserName string
}
func (d *deviceLogin) loginHandler(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
redirectBack(w, r, err.Error())
return
}
userCode := r.PostForm.Get("user_code")
if userCode == "" {
redirectBack(w, r, "missing user_code in request")
return
}
username := r.PostForm.Get("username")
if username == "" {
redirectBack(w, r, "missing username in request")
return
}
password := r.PostForm.Get("password")
if password == "" {
redirectBack(w, r, "missing password in request")
return
}
if err := d.storage.CheckUsernamePasswordSimple(username, password); err != nil {
redirectBack(w, r, err.Error())
return
}
state, err := d.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode)
if err != nil {
redirectBack(w, r, err.Error())
return
}
encoded, err := d.cookie.Encode(userCodeCookieName, userCodeCookie{userCode, username})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
cookie := &http.Cookie{
Name: userCodeCookieName,
Value: encoded,
Path: "/",
}
http.SetCookie(w, cookie)
renderConfirmPage(w, username, state.ClientID, state.Scopes)
}
func (d *deviceLogin) confirmHandler(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(userCodeCookieName)
if err != nil {
redirectBack(w, r, err.Error())
return
}
data := new(userCodeCookie)
if err = d.cookie.Decode(userCodeCookieName, cookie.Value, &data); err != nil {
redirectBack(w, r, err.Error())
return
}
if err = r.ParseForm(); err != nil {
redirectBack(w, r, err.Error())
return
}
action := r.Form.Get("action")
switch action {
case "allowed":
err = d.storage.CompleteDeviceAuthorization(r.Context(), data.UserCode, data.UserName)
case "denied":
err = d.storage.DenyDeviceAuthorization(r.Context(), data.UserCode)
default:
err = errors.New("action must be one of \"allow\" or \"deny\"")
}
if err != nil {
redirectBack(w, r, err.Error())
return
}
fmt.Fprintf(w, "Device authorization %s. You can now return to the device", action)
}

View file

@ -3,45 +3,11 @@ package exampleop
import (
"context"
"fmt"
"html/template"
"net/http"
"github.com/gorilla/mux"
)
const (
queryAuthRequestID = "authRequestID"
)
var loginTmpl, _ = template.New("login").Parse(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Login</title>
</head>
<body style="display: flex; align-items: center; justify-content: center; height: 100vh;">
<form method="POST" action="/login/username" style="height: 200px; width: 200px;">
<input type="hidden" name="id" value="{{.ID}}">
<div>
<label for="username">Username:</label>
<input id="username" name="username" style="width: 100%">
</div>
<div>
<label for="password">Password:</label>
<input id="password" name="password" style="width: 100%">
</div>
<p style="color:red; min-height: 1rem;">{{.Error}}</p>
<button type="submit">Login</button>
</form>
</body>
</html>`)
type login struct {
authenticate authenticate
router *mux.Router
@ -79,18 +45,14 @@ func (l *login) loginHandler(w http.ResponseWriter, r *http.Request) {
}
func renderLogin(w http.ResponseWriter, id string, err error) {
var errMsg string
if err != nil {
errMsg = err.Error()
}
data := &struct {
ID string
Error string
}{
ID: id,
Error: errMsg,
Error: errMsg(err),
}
err = loginTmpl.Execute(w, data)
err = templates.ExecuteTemplate(w, "login", data)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}

View file

@ -5,6 +5,7 @@ import (
"crypto/sha256"
"log"
"net/http"
"time"
"github.com/gorilla/mux"
"golang.org/x/text/language"
@ -27,7 +28,8 @@ func init() {
type Storage interface {
op.Storage
CheckUsernamePassword(username, password, id string) error
authenticate
deviceAuthenticate
}
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
@ -62,6 +64,9 @@ func SetupServer(ctx context.Context, issuer string, storage Storage) *mux.Route
// so we will direct all calls to /login to the login UI
router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router))
router.PathPrefix("/device").Subrouter()
registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter())
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
// is served on the correct path
//
@ -99,6 +104,13 @@ func newOP(ctx context.Context, storage op.Storage, issuer string, key [32]byte)
// this example has only static texts (in English), so we'll set the here accordingly
SupportedUILocales: []language.Tag{language.English},
DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second,
UserFormURL: issuer + "device",
UserCode: op.UserCodeBase20,
},
}
handler, err := op.NewOpenIDProvider(ctx, issuer, config, storage,
//we must explicitly allow the use of the http issuer

View file

@ -0,0 +1,26 @@
package exampleop
import (
"embed"
"html/template"
"github.com/sirupsen/logrus"
)
var (
//go:embed templates
templateFS embed.FS
templates = template.Must(template.ParseFS(templateFS, "templates/*.html"))
)
const (
queryAuthRequestID = "authRequestID"
)
func errMsg(err error) string {
if err == nil {
return ""
}
logrus.Error(err)
return err.Error()
}

View file

@ -0,0 +1,25 @@
{{ define "confirm_device" -}}
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Confirm device authorization</title>
<style>
.green{
background-color: green
}
.red{
background-color: red
}
</style>
</head>
<body>
<h1>Welcome back {{.Username}}!</h1>
<p>
You are about to grant device {{.ClientID}} access to the following scopes: {{.Scopes}}.
</p>
<button onclick="location.href='./confirm?action=allowed'" type="button" class="green">Allow</button>
<button onclick="location.href='./confirm?action=denied'" type="button" class="red">Deny</button>
</body>
</html>
{{- end }}

View file

@ -0,0 +1,29 @@
{{ define "device_login" -}}
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Login</title>
</head>
<body style="display: flex; align-items: center; justify-content: center; height: 100vh;">
<form method="POST" action="/device/login" style="height: 200px; width: 200px;">
<input type="hidden" name="user_code" value="{{.UserCode}}">
<div>
<label for="username">Username:</label>
<input id="username" name="username" style="width: 100%">
</div>
<div>
<label for="password">Password:</label>
<input id="password" name="password" style="width: 100%">
</div>
<p style="color:red; min-height: 1rem;">{{.Error}}</p>
<button type="submit">Login</button>
</form>
</body>
</html>
{{- end }}

View file

@ -0,0 +1,29 @@
{{ define "login" -}}
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Login</title>
</head>
<body style="display: flex; align-items: center; justify-content: center; height: 100vh;">
<form method="POST" action="/login/username" style="height: 200px; width: 200px;">
<input type="hidden" name="id" value="{{.ID}}">
<div>
<label for="username">Username:</label>
<input id="username" name="username" style="width: 100%">
</div>
<div>
<label for="password">Password:</label>
<input id="password" name="password" style="width: 100%">
</div>
<p style="color:red; min-height: 1rem;">{{.Error}}</p>
<button type="submit">Login</button>
</form>
</body>
</html>`
{{- end }}

View file

@ -0,0 +1,21 @@
{{ define "usercode" -}}
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Device authorization</title>
</head>
<body style="display: flex; align-items: center; justify-content: center; height: 100vh;">
<form method="POST" style="height: 200px; width: 200px;">
<h1>Device authorization</h1>
<div>
<label for="user_code">Code:</label>
<input id="user_code" name="user_code" style="width: 100%">
</div>
<p style="color:red; min-height: 1rem;">{{.Error}}</p>
<button type="submit">Login</button>
</form>
</body>
</html>
{{- end }}

View file

@ -44,6 +44,8 @@ type Storage struct {
services map[string]Service
refreshTokens map[string]*RefreshToken
signingKey signingKey
deviceCodes map[string]deviceAuthorizationEntry
userCodes map[string]string
}
type signingKey struct {
@ -105,6 +107,8 @@ func NewStorage(userStore UserStore) *Storage {
algorithm: jose.RS256,
key: key,
},
deviceCodes: make(map[string]deviceAuthorizationEntry),
userCodes: make(map[string]string),
}
}
@ -135,6 +139,17 @@ func (s *Storage) CheckUsernamePassword(username, password, id string) error {
return fmt.Errorf("username or password wrong")
}
func (s *Storage) CheckUsernamePasswordSimple(username, password string) error {
s.lock.Lock()
defer s.lock.Unlock()
user := s.userStore.GetUserByUsername(username)
if user != nil && user.Password == password {
return nil
}
return fmt.Errorf("username or password wrong")
}
// CreateAuthRequest implements the op.Storage interface
// it will be called after parsing and validation of the authentication request
func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
@ -735,3 +750,85 @@ func appendClaim(claims map[string]interface{}, claim string, value interface{})
claims[claim] = value
return claims
}
type deviceAuthorizationEntry struct {
deviceCode string
userCode string
state *op.DeviceAuthorizationState
}
func (s *Storage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error {
s.lock.Lock()
defer s.lock.Unlock()
if _, ok := s.clients[clientID]; !ok {
return errors.New("client not found")
}
if _, ok := s.userCodes[userCode]; ok {
return op.ErrDuplicateUserCode
}
s.deviceCodes[deviceCode] = deviceAuthorizationEntry{
deviceCode: deviceCode,
userCode: userCode,
state: &op.DeviceAuthorizationState{
ClientID: clientID,
Scopes: scopes,
Expires: expires,
},
}
s.userCodes[userCode] = deviceCode
return nil
}
func (s *Storage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*op.DeviceAuthorizationState, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.lock.Lock()
defer s.lock.Unlock()
entry, ok := s.deviceCodes[deviceCode]
if !ok || entry.state.ClientID != clientID {
return nil, errors.New("device code not found for client") // is there a standard not found error in the framework?
}
return entry.state, nil
}
func (s *Storage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) {
s.lock.Lock()
defer s.lock.Unlock()
entry, ok := s.deviceCodes[s.userCodes[userCode]]
if !ok {
return nil, errors.New("user code not found")
}
return entry.state, nil
}
func (s *Storage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error {
s.lock.Lock()
defer s.lock.Unlock()
entry, ok := s.deviceCodes[s.userCodes[userCode]]
if !ok {
return errors.New("user code not found")
}
entry.state.Subject = subject
entry.state.Done = true
return nil
}
func (s *Storage) DenyDeviceAuthorization(ctx context.Context, userCode string) error {
s.lock.Lock()
defer s.lock.Unlock()
s.deviceCodes[s.userCodes[userCode]].state.Denied = true
return nil
}

View file

@ -1,6 +1,8 @@
package client
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@ -186,3 +188,94 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti
IssuedAt: oidc.Time(iat),
}, signer)
}
type DeviceAuthorizationCaller interface {
GetDeviceAuthorizationEndpoint() string
HttpClient() *http.Client
}
func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
if err != nil {
return nil, err
}
if request.ClientSecret != "" {
req.SetBasicAuth(request.ClientID, request.ClientSecret)
}
resp := new(oidc.DeviceAuthorizationResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
return nil, err
}
return resp, nil
}
type DeviceAccessTokenRequest struct {
*oidc.ClientCredentialsRequest
oidc.DeviceAccessTokenRequest
}
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil)
if err != nil {
return nil, err
}
if request.ClientSecret != "" {
req.SetBasicAuth(request.ClientID, request.ClientSecret)
}
httpResp, err := caller.HttpClient().Do(req)
if err != nil {
return nil, err
}
defer httpResp.Body.Close()
resp := new(struct {
*oidc.AccessTokenResponse
*oidc.Error
})
if err = json.NewDecoder(httpResp.Body).Decode(resp); err != nil {
return nil, err
}
if httpResp.StatusCode == http.StatusOK {
return resp.AccessTokenResponse, nil
}
return nil, resp.Error
}
func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
for {
timer := time.After(interval)
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-timer:
}
ctx, cancel := context.WithTimeout(ctx, interval)
defer cancel()
resp, err := CallDeviceAccessTokenEndpoint(ctx, request, caller)
if err == nil {
return resp, nil
}
if errors.Is(err, context.DeadlineExceeded) {
interval += 5 * time.Second
}
var target *oidc.Error
if !errors.As(err, &target) {
return nil, err
}
switch target.ErrorType {
case oidc.AuthorizationPending:
continue
case oidc.SlowDown:
interval += 5 * time.Second
continue
default:
return nil, err
}
}
}

62
pkg/client/rp/device.go Normal file
View file

@ -0,0 +1,62 @@
package rp
import (
"context"
"fmt"
"time"
"github.com/zitadel/oidc/v2/pkg/client"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
confg := rp.OAuthConfig()
req := &oidc.ClientCredentialsRequest{
GrantType: oidc.GrantTypeDeviceCode,
Scope: scopes,
ClientID: confg.ClientID,
ClientSecret: confg.ClientSecret,
}
if signer := rp.Signer(); signer != nil {
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer)
if err != nil {
return nil, fmt.Errorf("failed to build assertion: %w", err)
}
req.ClientAssertion = assertion
req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion
}
return req, nil
}
// DeviceAuthorization starts a new Device Authorization flow as defined
// in RFC 8628, section 3.1 and 3.2:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
req, err := newDeviceClientCredentialsRequest(scopes, rp)
if err != nil {
return nil, err
}
return client.CallDeviceAuthorizationEndpoint(req, rp)
}
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
// by means of polling as defined in RFC, section 3.3 and 3.4:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
req := &client.DeviceAccessTokenRequest{
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
GrantType: oidc.GrantTypeDeviceCode,
DeviceCode: deviceCode,
},
}
req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp)
if err != nil {
return nil, err
}
return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp})
}

View file

@ -59,6 +59,10 @@ type RelyingParty interface {
// UserinfoEndpoint returns the userinfo
UserinfoEndpoint() string
// GetDeviceAuthorizationEndpoint returns the enpoint which can
// be used to start a DeviceAuthorization flow.
GetDeviceAuthorizationEndpoint() string
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
IDTokenVerifier() IDTokenVerifier
// ErrorHandler returns the handler used for callback errors
@ -121,6 +125,10 @@ func (rp *relyingParty) UserinfoEndpoint() string {
return rp.endpoints.UserinfoURL
}
func (rp *relyingParty) GetDeviceAuthorizationEndpoint() string {
return rp.endpoints.DeviceAuthorizationURL
}
func (rp *relyingParty) GetEndSessionEndpoint() string {
return rp.endpoints.EndSessionURL
}
@ -500,6 +508,7 @@ type Endpoints struct {
JKWsURL string
EndSessionURL string
RevokeURL string
DeviceAuthorizationURL string
}
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
@ -514,6 +523,7 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
JKWsURL: discoveryConfig.JwksURI,
EndSessionURL: discoveryConfig.EndSessionEndpoint,
RevokeURL: discoveryConfig.RevocationEndpoint,
DeviceAuthorizationURL: discoveryConfig.DeviceAuthorizationEndpoint,
}
}

View file

@ -0,0 +1,29 @@
package oidc
// DeviceAuthorizationRequest implements
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1,
// 3.1 Device Authorization Request.
type DeviceAuthorizationRequest struct {
Scopes SpaceDelimitedArray `schema:"scope"`
ClientID string `schema:"client_id"`
}
// DeviceAuthorizationResponse implements
// https://www.rfc-editor.org/rfc/rfc8628#section-3.2
// 3.2. Device Authorization Response.
type DeviceAuthorizationResponse struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval,omitempty"`
}
// DeviceAccessTokenRequest implements
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4,
// Device Access Token Request.
type DeviceAccessTokenRequest struct {
GrantType GrantType `json:"grant_type" schema:"grant_type"`
DeviceCode string `json:"device_code" schema:"device_code"`
}

View file

@ -30,6 +30,8 @@ type DiscoveryConfiguration struct {
// EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP.
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"`
// CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client.
CheckSessionIframe string `json:"check_session_iframe,omitempty"`

View file

@ -18,6 +18,14 @@ const (
InteractionRequired errorType = "interaction_required"
LoginRequired errorType = "login_required"
RequestNotSupported errorType = "request_not_supported"
// Additional error codes as defined in
// https://www.rfc-editor.org/rfc/rfc8628#section-3.5
// Device Access Token Response
AuthorizationPending errorType = "authorization_pending"
SlowDown errorType = "slow_down"
AccessDenied errorType = "access_denied"
ExpiredToken errorType = "expired_token"
)
var (
@ -77,6 +85,32 @@ var (
ErrorType: RequestNotSupported,
}
}
// Device Access Token errors:
ErrAuthorizationPending = func() *Error {
return &Error{
ErrorType: AuthorizationPending,
Description: "The client SHOULD repeat the access token request to the token endpoint, after interval from device authorization response.",
}
}
ErrSlowDown = func() *Error {
return &Error{
ErrorType: SlowDown,
Description: "Polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests.",
}
}
ErrAccessDenied = func() *Error {
return &Error{
ErrorType: AccessDenied,
Description: "The authorization request was denied.",
}
}
ErrExpiredDeviceCode = func() *Error {
return &Error{
ErrorType: ExpiredToken,
Description: "The \"device_code\" has expired.",
}
}
)
type Error struct {

View file

@ -27,6 +27,9 @@ const (
// GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code
GrantTypeImplicit GrantType = "implicit"
// GrantTypeDeviceCode
GrantTypeDeviceCode GrantType = "urn:ietf:params:oauth:grant-type:device_code"
// ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`
// used for the OAuth JWT Profile Client Authentication
ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
@ -35,7 +38,7 @@ const (
var AllGrantTypes = []GrantType{
GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials,
GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit,
ClientAssertionTypeJWTAssertion,
GrantTypeDeviceCode, ClientAssertionTypeJWTAssertion,
}
type GrantType string

View file

@ -1,8 +1,13 @@
package op
import (
"context"
"errors"
"net/http"
"net/url"
"time"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
@ -57,3 +62,95 @@ func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseT
func IsConfidentialType(c Client) bool {
return c.ApplicationType() == ApplicationTypeWeb
}
var (
ErrInvalidAuthHeader = errors.New("invalid basic auth header")
ErrNoClientCredentials = errors.New("no client credentials provided")
ErrMissingClientID = errors.New("client_id missing from request")
)
type ClientJWTProfile interface {
JWTProfileVerifier(context.Context) JWTProfileVerifier
}
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
if ca.ClientAssertion == "" {
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
}
profile, err := VerifyJWTAssertion(ctx, ca.ClientAssertion, verifier.JWTProfileVerifier(ctx))
if err != nil {
return "", oidc.ErrUnauthorizedClient().WithParent(err).WithDescription("JWT assertion failed")
}
return profile.Issuer, nil
}
func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) {
clientID, clientSecret, ok := r.BasicAuth()
if !ok {
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
}
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
}
if err := storage.AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
return "", oidc.ErrUnauthorizedClient().WithParent(err)
}
return clientID, nil
}
type ClientProvider interface {
Decoder() httphelper.Decoder
Storage() Storage
}
type clientData struct {
ClientID string `schema:"client_id"`
oidc.ClientAssertionParams
}
// ClientIDFromRequest parses the request form and tries to obtain the client ID
// and reports if it is authenticated, using a JWT or static client secrets over
// http basic auth.
//
// If the Provider implements IntrospectorJWTProfile and "client_assertion" is
// present in the form data, JWT assertion will be verified and the
// client ID is taken from there.
// If any of them is absent, basic auth is attempted.
// In absence of basic auth data, the unauthenticated client id from the form
// data is returned.
//
// If no client id can be obtained by any method, oidc.ErrInvalidClient
// is returned with ErrMissingClientID wrapped in it.
func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, authenticated bool, err error) {
err = r.ParseForm()
if err != nil {
return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
}
data := new(clientData)
if err = p.Decoder().Decode(data, r.PostForm); err != nil {
return "", false, err
}
JWTProfile, ok := p.(ClientJWTProfile)
if ok {
clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
}
if !ok || errors.Is(err, ErrNoClientCredentials) {
clientID, err = ClientBasicAuth(r, p.Storage())
}
if err == nil {
return clientID, true, nil
}
if data.ClientID == "" {
return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
}
return data.ClientID, false, nil
}

253
pkg/op/client_test.go Normal file
View file

@ -0,0 +1,253 @@
package op_test
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock"
)
type testClientJWTProfile struct{}
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
func TestClientJWTAuth(t *testing.T) {
type args struct {
ctx context.Context
ca oidc.ClientAssertionParams
verifier op.ClientJWTProfile
}
tests := []struct {
name string
args args
wantClientID string
wantErr error
}{
{
name: "empty assertion",
args: args{
context.Background(),
oidc.ClientAssertionParams{},
testClientJWTProfile{},
},
wantErr: op.ErrNoClientCredentials,
},
{
name: "verification error",
args: args{
context.Background(),
oidc.ClientAssertionParams{
ClientAssertion: "foo",
},
testClientJWTProfile{},
},
wantErr: oidc.ErrParse,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotClientID, err := op.ClientJWTAuth(tt.args.ctx, tt.args.ca, tt.args.verifier)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantClientID, gotClientID)
})
}
}
func TestClientBasicAuth(t *testing.T) {
errWrong := errors.New("wrong secret")
type args struct {
username string
password string
}
tests := []struct {
name string
args *args
storage op.Storage
wantClientID string
wantErr error
}{
{
name: "no args",
wantErr: op.ErrNoClientCredentials,
},
{
name: "username unescape err",
args: &args{
username: "%",
password: "bar",
},
wantErr: op.ErrInvalidAuthHeader,
},
{
name: "password unescape err",
args: &args{
username: "foo",
password: "%",
},
wantErr: op.ErrInvalidAuthHeader,
},
{
name: "auth error",
args: &args{
username: "foo",
password: "wrong",
},
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong)
return s
}(),
wantErr: errWrong,
},
{
name: "auth error",
args: &args{
username: "foo",
password: "bar",
},
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
return s
}(),
wantClientID: "foo",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/foo", nil)
if tt.args != nil {
r.SetBasicAuth(tt.args.username, tt.args.password)
}
gotClientID, err := op.ClientBasicAuth(r, tt.storage)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantClientID, gotClientID)
})
}
}
type errReader struct{}
func (errReader) Read([]byte) (int, error) {
return 0, io.ErrNoProgress
}
type testClientProvider struct {
storage op.Storage
}
func (testClientProvider) Decoder() httphelper.Decoder {
return schema.NewDecoder()
}
func (p testClientProvider) Storage() op.Storage {
return p.storage
}
func TestClientIDFromRequest(t *testing.T) {
type args struct {
body io.Reader
p op.ClientProvider
}
type basicAuth struct {
username string
password string
}
tests := []struct {
name string
args args
basicAuth *basicAuth
wantClientID string
wantAuthenticated bool
wantErr bool
}{
{
name: "parse error",
args: args{
body: errReader{},
},
wantErr: true,
},
{
name: "unauthenticated",
args: args{
body: strings.NewReader(
url.Values{
"client_id": []string{"foo"},
}.Encode(),
),
p: testClientProvider{
storage: mock.NewStorage(t),
},
},
wantClientID: "foo",
wantAuthenticated: false,
},
{
name: "authenticated",
args: args{
body: strings.NewReader(
url.Values{}.Encode(),
),
p: testClientProvider{
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
return s
}(),
},
},
basicAuth: &basicAuth{
username: "foo",
password: "bar",
},
wantClientID: "foo",
wantAuthenticated: true,
},
{
name: "missing client id",
args: args{
body: strings.NewReader(
url.Values{}.Encode(),
),
p: testClientProvider{
storage: mock.NewStorage(t),
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if tt.basicAuth != nil {
r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
}
gotClientID, gotAuthenticated, err := op.ClientIDFromRequest(r, tt.args.p)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.wantClientID, gotClientID)
assert.Equal(t, tt.wantAuthenticated, gotAuthenticated)
})
}
}

View file

@ -27,6 +27,7 @@ type Configuration interface {
RevocationEndpoint() Endpoint
EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint
DeviceAuthorizationEndpoint() Endpoint
AuthMethodPostSupported() bool
CodeMethodS256Supported() bool
@ -36,6 +37,7 @@ type Configuration interface {
GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool
GrantTypeDeviceCodeSupported() bool
IntrospectionAuthMethodPrivateKeyJWTSupported() bool
IntrospectionEndpointSigningAlgorithmsSupported() []string
RevocationAuthMethodPrivateKeyJWTSupported() bool
@ -44,6 +46,7 @@ type Configuration interface {
RequestObjectSigningAlgorithmsSupported() []string
SupportedUILocales() []language.Tag
DeviceAuthorization() DeviceAuthorizationConfig
}
type IssuerFromRequest func(r *http.Request) string

265
pkg/op/device.go Normal file
View file

@ -0,0 +1,265 @@
package op
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"time"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type DeviceAuthorizationConfig struct {
Lifetime time.Duration
PollInterval time.Duration
UserFormURL string // the URL where the user must go to authorize the device
UserCode UserCodeConfig
}
type UserCodeConfig struct {
CharSet string
CharAmount int
DashInterval int
}
const (
CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ"
CharSetDigits = "0123456789"
)
var (
UserCodeBase20 = UserCodeConfig{
CharSet: CharSetBase20,
CharAmount: 8,
DashInterval: 4,
}
UserCodeDigits = UserCodeConfig{
CharSet: CharSetDigits,
CharAmount: 9,
DashInterval: 3,
}
)
func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if err := DeviceAuthorization(w, r, o); err != nil {
RequestError(w, r, err)
}
}
}
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
storage, err := assertDeviceStorage(o.Storage())
if err != nil {
return err
}
req, err := ParseDeviceCodeRequest(r, o)
if err != nil {
return err
}
config := o.DeviceAuthorization()
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
if err != nil {
return err
}
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
if err != nil {
return err
}
expires := time.Now().Add(config.Lifetime)
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
if err != nil {
return err
}
response := &oidc.DeviceAuthorizationResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: config.UserFormURL,
ExpiresIn: int(config.Lifetime / time.Second),
Interval: int(config.PollInterval / time.Second),
}
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode)
httphelper.MarshalJSON(w, response)
return nil
}
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
clientID, _, err := ClientIDFromRequest(r, o)
if err != nil {
return nil, err
}
req := new(oidc.DeviceAuthorizationRequest)
if err := o.Decoder().Decode(req, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
}
req.ClientID = clientID
return req, nil
}
// 16 bytes gives 128 bit of entropy.
// results in a 22 character base64 encoded string.
const RecommendedDeviceCodeBytes = 16
func NewDeviceCode(nBytes int) (string, error) {
bytes := make([]byte, nBytes)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("%w getting entropy for device code", err)
}
return base64.RawURLEncoding.EncodeToString(bytes), nil
}
func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
var buf strings.Builder
if dashInterval > 0 {
buf.Grow(charAmount + charAmount/dashInterval - 1)
} else {
buf.Grow(charAmount)
}
max := big.NewInt(int64(len(charSet)))
for i := 0; i < charAmount; i++ {
if dashInterval != 0 && i != 0 && i%dashInterval == 0 {
buf.WriteByte('-')
}
bi, err := rand.Int(rand.Reader, max)
if err != nil {
return "", fmt.Errorf("%w getting entropy for user code", err)
}
buf.WriteRune(charSet[int(bi.Int64())])
}
return buf.String(), nil
}
type deviceAccessTokenRequest struct {
subject string
audience []string
scopes []string
}
func (r *deviceAccessTokenRequest) GetSubject() string {
return r.subject
}
func (r *deviceAccessTokenRequest) GetAudience() []string {
return r.audience
}
func (r *deviceAccessTokenRequest) GetScopes() []string {
return r.scopes
}
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
if err := deviceAccessToken(w, r, exchanger); err != nil {
RequestError(w, r, err)
}
}
func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error {
// use a limited context timeout shorter as the default
// poll interval of 5 seconds.
ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
defer cancel()
r = r.WithContext(ctx)
clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger)
if err != nil {
return err
}
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
if err != nil {
return err
}
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
if err != nil {
return err
}
client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
if err != nil {
return err
}
if clientAuthenticated != IsConfidentialType(client) {
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
WithDescription("confidential client requires authentication")
}
tokenRequest := &deviceAccessTokenRequest{
subject: state.Subject,
audience: []string{clientID},
scopes: state.Scopes,
}
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
if err != nil {
return err
}
httphelper.MarshalJSON(w, resp)
return nil
}
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
req := new(oidc.DeviceAccessTokenRequest)
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
return nil, err
}
return req, nil
}
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
storage, err := assertDeviceStorage(exchanger.Storage())
if err != nil {
return nil, err
}
state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
if errors.Is(err, context.DeadlineExceeded) {
return nil, oidc.ErrSlowDown().WithParent(err)
}
if err != nil {
return nil, oidc.ErrAccessDenied().WithParent(err)
}
if state.Denied {
return state, oidc.ErrAccessDenied()
}
if state.Done {
return state, nil
}
if time.Now().After(state.Expires) {
return state, oidc.ErrExpiredDeviceCode()
}
return state, oidc.ErrAuthorizationPending()
}
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) {
accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, client, "")
if err != nil {
return nil, err
}
return &oidc.AccessTokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
}, nil
}

453
pkg/op/device_test.go Normal file
View file

@ -0,0 +1,453 @@
package op_test
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"io"
mr "math/rand"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"golang.org/x/text/language"
)
var testProvider op.OpenIDProvider
const (
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
config := &op.Config{
CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut,
CodeMethodS256: true,
AuthMethodPost: true,
AuthMethodPrivateKeyJWT: true,
GrantTypeRefreshToken: true,
RequestObjectSupported: true,
SupportedUILocales: []language.Tag{language.English},
DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second,
UserFormURL: testIssuer + "device",
UserCode: op.UserCodeBase20,
},
}
storage.RegisterClients(
storage.NativeClient("native"),
storage.WebClient("web", "secret"),
storage.WebClient("api", "secret"),
)
var err error
testProvider, err = op.NewOpenIDProvider(context.TODO(), testIssuer, config,
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
)
if err != nil {
panic(err)
}
}
func Test_deviceAuthorizationHandler(t *testing.T) {
req := &oidc.DeviceAuthorizationRequest{
Scopes: []string{"foo", "bar"},
ClientID: "web",
}
values := make(url.Values)
testProvider.Encoder().Encode(req, values)
body := strings.NewReader(values.Encode())
r := httptest.NewRequest(http.MethodPost, "/", body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
runWithRandReader(mr.New(mr.NewSource(1)), func() {
op.DeviceAuthorizationHandler(testProvider)(w, r)
})
result := w.Result()
assert.Less(t, result.StatusCode, 300)
got, _ := io.ReadAll(result.Body)
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
}
func TestParseDeviceCodeRequest(t *testing.T) {
tests := []struct {
name string
req *oidc.DeviceAuthorizationRequest
wantErr bool
}{
{
name: "empty request",
wantErr: true,
},
/* decoding a SpaceDelimitedArray is broken
https://github.com/zitadel/oidc/issues/295
{
name: "success",
req: &oidc.DeviceAuthorizationRequest{
Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
ClientID: "web",
},
},
*/
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var body io.Reader
if tt.req != nil {
values := make(url.Values)
testProvider.Encoder().Encode(tt.req, values)
body = strings.NewReader(values.Encode())
}
r := httptest.NewRequest(http.MethodPost, "/", body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
got, err := op.ParseDeviceCodeRequest(r, testProvider)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.req, got)
})
}
}
func runWithRandReader(r io.Reader, f func()) {
originalReader := rand.Reader
rand.Reader = r
defer func() {
rand.Reader = originalReader
}()
f()
}
func TestNewDeviceCode(t *testing.T) {
t.Run("reader error", func(t *testing.T) {
runWithRandReader(errReader{}, func() {
_, err := op.NewDeviceCode(16)
require.Error(t, err)
})
})
t.Run("different lengths, rand reader", func(t *testing.T) {
for i := 1; i <= 32; i++ {
got, err := op.NewDeviceCode(i)
require.NoError(t, err)
assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i))
}
})
}
func TestNewUserCode(t *testing.T) {
type args struct {
charset []rune
charAmount int
dashInterval int
}
tests := []struct {
name string
args args
reader io.Reader
want string
wantErr bool
}{
{
name: "reader error",
args: args{
charset: []rune(op.CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: errReader{},
wantErr: true,
},
{
name: "base20",
args: args{
charset: []rune(op.CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: mr.New(mr.NewSource(1)),
want: "XKCD-HTTD",
},
{
name: "digits",
args: args{
charset: []rune(op.CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: mr.New(mr.NewSource(1)),
want: "271-256-225",
},
{
name: "no dashes",
args: args{
charset: []rune(op.CharSetDigits),
charAmount: 9,
},
reader: mr.New(mr.NewSource(1)),
want: "271256225",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
runWithRandReader(tt.reader, func() {
got, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
if tt.wantErr {
require.ErrorIs(t, err, io.ErrNoProgress)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
})
}
t.Run("crypto/rand", func(t *testing.T) {
const testN = 100000
for _, c := range []op.UserCodeConfig{op.UserCodeBase20, op.UserCodeDigits} {
t.Run(c.CharSet, func(t *testing.T) {
results := make(map[string]int)
for i := 0; i < testN; i++ {
code, err := op.NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval)
require.NoError(t, err)
results[code]++
}
t.Log(results)
var duplicates int
for code, count := range results {
assert.Less(t, count, 3, code)
if count == 2 {
duplicates++
}
}
})
}
})
}
func BenchmarkNewUserCode(b *testing.B) {
type args struct {
charset []rune
charAmount int
dashInterval int
}
tests := []struct {
name string
args args
reader io.Reader
}{
{
name: "math rand, base20",
args: args{
charset: []rune(op.CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: mr.New(mr.NewSource(1)),
},
{
name: "math rand, digits",
args: args{
charset: []rune(op.CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: mr.New(mr.NewSource(1)),
},
{
name: "crypto rand, base20",
args: args{
charset: []rune(op.CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: rand.Reader,
},
{
name: "crypto rand, digits",
args: args{
charset: []rune(op.CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: rand.Reader,
},
}
for _, tt := range tests {
runWithRandReader(tt.reader, func() {
b.Run(tt.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
require.NoError(b, err)
}
})
})
}
}
func TestDeviceAccessToken(t *testing.T) {
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"})
storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim")
values := make(url.Values)
values.Set("client_id", "native")
values.Set("grant_type", string(oidc.GrantTypeDeviceCode))
values.Set("device_code", "qwerty")
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(values.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
op.DeviceAccessToken(w, r, testProvider)
result := w.Result()
got, _ := io.ReadAll(result.Body)
t.Log(string(got))
assert.Less(t, result.StatusCode, 300)
assert.NotEmpty(t, string(got))
}
func TestCheckDeviceAuthorizationState(t *testing.T) {
now := time.Now()
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"})
storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"})
storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"})
storage.StoreDeviceAuthorization(context.Background(), "native", "expired", "expired", now.Add(-time.Minute), []string{"foo"})
storage.DenyDeviceAuthorization(context.Background(), "denied")
storage.CompleteDeviceAuthorization(context.Background(), "completed", "tim")
exceededCtx, cancel := context.WithTimeout(context.Background(), -time.Second)
defer cancel()
type args struct {
ctx context.Context
clientID string
deviceCode string
}
tests := []struct {
name string
args args
want *op.DeviceAuthorizationState
wantErr error
}{
{
name: "pending",
args: args{
ctx: context.Background(),
clientID: "native",
deviceCode: "pending",
},
want: &op.DeviceAuthorizationState{
ClientID: "native",
Scopes: []string{"foo"},
Expires: now.Add(time.Minute),
},
wantErr: oidc.ErrAuthorizationPending(),
},
{
name: "slow down",
args: args{
ctx: exceededCtx,
clientID: "native",
deviceCode: "ok",
},
wantErr: oidc.ErrSlowDown(),
},
{
name: "wrong client",
args: args{
ctx: context.Background(),
clientID: "foo",
deviceCode: "ok",
},
wantErr: oidc.ErrAccessDenied(),
},
{
name: "denied",
args: args{
ctx: context.Background(),
clientID: "native",
deviceCode: "denied",
},
want: &op.DeviceAuthorizationState{
ClientID: "native",
Scopes: []string{"foo"},
Expires: now.Add(time.Minute),
Denied: true,
},
wantErr: oidc.ErrAccessDenied(),
},
{
name: "completed",
args: args{
ctx: context.Background(),
clientID: "native",
deviceCode: "completed",
},
want: &op.DeviceAuthorizationState{
ClientID: "native",
Scopes: []string{"foo"},
Expires: now.Add(time.Minute),
Subject: "tim",
Done: true,
},
},
{
name: "expired",
args: args{
ctx: context.Background(),
clientID: "native",
deviceCode: "expired",
},
want: &op.DeviceAuthorizationState{
ClientID: "native",
Scopes: []string{"foo"},
Expires: now.Add(-time.Minute),
},
wantErr: oidc.ErrExpiredDeviceCode(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := op.CheckDeviceAuthorizationState(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, testProvider)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -44,6 +44,7 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov
RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer),
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
JwksURI: config.KeysEndpoint().Absolute(issuer),
DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
@ -92,6 +93,9 @@ func GrantTypes(c Configuration) []oidc.GrantType {
if c.GrantTypeJWTAuthorizationSupported() {
grantTypes = append(grantTypes, oidc.GrantTypeBearer)
}
if c.GrantTypeDeviceCodeSupported() {
grantTypes = append(grantTypes, oidc.GrantTypeDeviceCode)
}
return grantTypes
}

View file

@ -131,6 +131,7 @@ func Test_GrantTypes(t *testing.T) {
c.EXPECT().GrantTypeTokenExchangeSupported().Return(false)
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(false)
c.EXPECT().GrantTypeClientCredentialsSupported().Return(false)
c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
return c
}(),
},
@ -148,6 +149,7 @@ func Test_GrantTypes(t *testing.T) {
c.EXPECT().GrantTypeTokenExchangeSupported().Return(true)
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(true)
c.EXPECT().GrantTypeClientCredentialsSupported().Return(true)
c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
return c
}(),
},

View file

@ -92,6 +92,34 @@ func (mr *MockConfigurationMockRecorder) CodeMethodS256Supported() *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported))
}
// DeviceAuthorization mocks base method.
func (m *MockConfiguration) DeviceAuthorization() op.DeviceAuthorizationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorization")
ret0, _ := ret[0].(op.DeviceAuthorizationConfig)
return ret0
}
// DeviceAuthorization indicates an expected call of DeviceAuthorization.
func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorization", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorization))
}
// DeviceAuthorizationEndpoint mocks base method.
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// DeviceAuthorizationEndpoint indicates an expected call of DeviceAuthorizationEndpoint.
func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorizationEndpoint))
}
// EndSessionEndpoint mocks base method.
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
m.ctrl.T.Helper()
@ -120,6 +148,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeClientCredentialsSupported() *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported))
}
// GrantTypeDeviceCodeSupported mocks base method.
func (m *MockConfiguration) GrantTypeDeviceCodeSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GrantTypeDeviceCodeSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// GrantTypeDeviceCodeSupported indicates an expected call of GrantTypeDeviceCodeSupported.
func (mr *MockConfigurationMockRecorder) GrantTypeDeviceCodeSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeDeviceCodeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeDeviceCodeSupported))
}
// GrantTypeJWTAuthorizationSupported mocks base method.
func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool {
m.ctrl.T.Helper()
@ -358,6 +400,20 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
}
// UserCodeFormEndpoint mocks base method.
func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserCodeFormEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint.
func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint))
}
// UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
m.ctrl.T.Helper()

View file

@ -27,6 +27,7 @@ const (
defaultRevocationEndpoint = "revoke"
defaultEndSessionEndpoint = "end_session"
defaultKeysEndpoint = "keys"
defaultDeviceAuthzEndpoint = "/device_authorization"
)
var (
@ -38,6 +39,7 @@ var (
Revocation: NewEndpoint(defaultRevocationEndpoint),
EndSession: NewEndpoint(defaultEndSessionEndpoint),
JwksURI: NewEndpoint(defaultKeysEndpoint),
DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
}
defaultCORSOptions = cors.Options{
@ -95,6 +97,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o))
return router
}
@ -118,6 +121,7 @@ type Config struct {
GrantTypeRefreshToken bool
RequestObjectSupported bool
SupportedUILocales []language.Tag
DeviceAuthorization DeviceAuthorizationConfig
}
type endpoints struct {
@ -129,6 +133,7 @@ type endpoints struct {
EndSession Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
DeviceAuthorization Endpoint
}
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@ -145,6 +150,7 @@ type endpoints struct {
// /revoke
// /end_session
// /keys
// /device_authorization
//
// This does not include login. Login is handled with a redirect that includes the
// request ID. The redirect for logins is specified per-client by Client.LoginURL().
@ -242,6 +248,10 @@ func (o *Provider) EndSessionEndpoint() Endpoint {
return o.endpoints.EndSession
}
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
return o.endpoints.DeviceAuthorization
}
func (o *Provider) KeysEndpoint() Endpoint {
return o.endpoints.JwksURI
}
@ -275,6 +285,11 @@ func (o *Provider) GrantTypeJWTAuthorizationSupported() bool {
return true
}
func (o *Provider) GrantTypeDeviceCodeSupported() bool {
_, ok := o.storage.(DeviceAuthorizationStorage)
return ok
}
func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
return true
}
@ -308,6 +323,10 @@ func (o *Provider) SupportedUILocales() []language.Tag {
return o.config.SupportedUILocales
}
func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig {
return o.config.DeviceAuthorization
}
func (o *Provider) Storage() Storage {
return o.storage
}

View file

@ -151,3 +151,50 @@ type EndSessionRequest struct {
ClientID string
RedirectURI string
}
var ErrDuplicateUserCode = errors.New("user code already exists")
type DeviceAuthorizationState struct {
ClientID string
Scopes []string
Expires time.Time
Done bool
Subject string
Denied bool
}
type DeviceAuthorizationStorage interface {
// StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
// User code will be used by the user to complete the login flow and must be unique.
// ErrDuplicateUserCode signals the caller should try again with a new code.
//
// Note that user codes are low entropy keys and when many exist in the
// database, the change for collisions increases. Therefore implementers
// of this interface must make sure that user codes of expired authentication flows are purged,
// after some time.
StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error
// GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database.
// The method is polled untill the the authorization is eighter Completed, Expired or Denied.
GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error)
// GetDeviceAuthorizationByUserCode resturn the current state of the device authorization flow,
// identified by the user code.
GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*DeviceAuthorizationState, error)
// CompleteDeviceAuthorization marks a device authorization entry as Completed,
// identified by userCode. The Subject is added to the state, so that
// GetDeviceAuthorizatonState can use it to create a new Access Token.
CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
// DenyDeviceAuthorization marks a device authorization entry as Denied.
DenyDeviceAuthorization(ctx context.Context, userCode string) error
}
func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) {
storage, ok := s.(DeviceAuthorizationStorage)
if !ok {
return nil, oidc.ErrUnsupportedGrantType().WithDescription("device_code grant not supported")
}
return storage, nil
}

View file

@ -4,7 +4,6 @@ import (
"context"
"errors"
"net/http"
"net/url"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
@ -50,38 +49,19 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
}
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {
err = r.ParseForm()
clientID, authenticated, err := ClientIDFromRequest(r, introspector)
if err != nil {
return "", "", errors.New("unable to parse request")
return "", "", err
}
req := new(struct {
oidc.IntrospectionRequest
oidc.ClientAssertionParams
})
if !authenticated {
return "", "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
}
req := new(oidc.IntrospectionRequest)
err = introspector.Decoder().Decode(req, r.Form)
if err != nil {
return "", "", errors.New("unable to parse request")
}
if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" {
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier(r.Context()))
if err == nil {
return req.Token, profile.Issuer, nil
}
}
clientID, clientSecret, ok := r.BasicAuth()
if ok {
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return "", "", errors.New("invalid basic auth header")
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return "", "", errors.New("invalid basic auth header")
}
if err := introspector.Storage().AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
return "", "", err
}
return req.Token, clientID, nil
}
return "", "", errors.New("invalid authorization")
}

View file

@ -19,6 +19,7 @@ type Exchanger interface {
GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool
GrantTypeDeviceCodeSupported() bool
AccessTokenVerifier(context.Context) AccessTokenVerifier
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
}
@ -56,6 +57,11 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
ClientCredentialsExchange(w, r, exchanger)
return
}
case string(oidc.GrantTypeDeviceCode):
if exchanger.GrantTypeDeviceCodeSupported() {
DeviceAccessToken(w, r, exchanger)
return
}
case "":
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
return