implement RFC 8628: Device authorization grant
This commit is contained in:
parent
03f71a67c2
commit
2342f208ef
29 changed files with 1968 additions and 97 deletions
61
example/client/device/device.go
Normal file
61
example/client/device/device.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||||
|
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
key = []byte("test1234test1234")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
clientID := os.Getenv("CLIENT_ID")
|
||||||
|
clientSecret := os.Getenv("CLIENT_SECRET")
|
||||||
|
keyPath := os.Getenv("KEY_PATH")
|
||||||
|
issuer := os.Getenv("ISSUER")
|
||||||
|
scopes := strings.Split(os.Getenv("SCOPES"), " ")
|
||||||
|
|
||||||
|
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||||
|
|
||||||
|
var options []rp.Option
|
||||||
|
if clientSecret == "" {
|
||||||
|
options = append(options, rp.WithPKCE(cookieHandler))
|
||||||
|
}
|
||||||
|
if keyPath != "" {
|
||||||
|
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatalf("error creating provider %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Info("starting device authorization flow")
|
||||||
|
resp, err := rp.DeviceAuthorization(scopes, provider)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatal(err)
|
||||||
|
}
|
||||||
|
logrus.Info("resp", resp)
|
||||||
|
fmt.Printf("\nPlease browse to %s and enter code %s\n", resp.VerificationURI, resp.UserCode)
|
||||||
|
|
||||||
|
logrus.Info("start polling")
|
||||||
|
token, err := rp.DeviceAccessToken(ctx, resp.DeviceCode, time.Duration(resp.Interval)*time.Second, provider)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatal(err)
|
||||||
|
}
|
||||||
|
logrus.Infof("successfully obtained token: %v", token)
|
||||||
|
}
|
191
example/server/exampleop/device.go
Normal file
191
example/server/exampleop/device.go
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
package exampleop
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/gorilla/securecookie"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/op"
|
||||||
|
)
|
||||||
|
|
||||||
|
type deviceAuthenticate interface {
|
||||||
|
CheckUsernamePasswordSimple(username, password string) error
|
||||||
|
op.DeviceAuthorizationStorage
|
||||||
|
}
|
||||||
|
|
||||||
|
type deviceLogin struct {
|
||||||
|
storage deviceAuthenticate
|
||||||
|
cookie *securecookie.SecureCookie
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) {
|
||||||
|
l := &deviceLogin{
|
||||||
|
storage: storage,
|
||||||
|
cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil),
|
||||||
|
}
|
||||||
|
|
||||||
|
router.HandleFunc("", l.userCodeHandler)
|
||||||
|
router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler)
|
||||||
|
router.HandleFunc("/confirm", l.confirmHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderUserCode(w io.Writer, err error) {
|
||||||
|
data := struct {
|
||||||
|
Error string
|
||||||
|
}{
|
||||||
|
Error: errMsg(err),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := templates.ExecuteTemplate(w, "usercode", data); err != nil {
|
||||||
|
logrus.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderDeviceLogin(w http.ResponseWriter, userCode string, err error) {
|
||||||
|
data := &struct {
|
||||||
|
UserCode string
|
||||||
|
Error string
|
||||||
|
}{
|
||||||
|
UserCode: userCode,
|
||||||
|
Error: errMsg(err),
|
||||||
|
}
|
||||||
|
if err = templates.ExecuteTemplate(w, "device_login", data); err != nil {
|
||||||
|
logrus.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderConfirmPage(w http.ResponseWriter, username, clientID string, scopes []string) {
|
||||||
|
data := &struct {
|
||||||
|
Username string
|
||||||
|
ClientID string
|
||||||
|
Scopes []string
|
||||||
|
}{
|
||||||
|
Username: username,
|
||||||
|
ClientID: clientID,
|
||||||
|
Scopes: scopes,
|
||||||
|
}
|
||||||
|
if err := templates.ExecuteTemplate(w, "confirm_device", data); err != nil {
|
||||||
|
logrus.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *deviceLogin) userCodeHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
renderUserCode(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userCode := r.Form.Get("user_code")
|
||||||
|
if userCode == "" {
|
||||||
|
if prompt, _ := url.QueryUnescape(r.Form.Get("prompt")); prompt != "" {
|
||||||
|
err = errors.New(prompt)
|
||||||
|
}
|
||||||
|
renderUserCode(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
renderDeviceLogin(w, userCode, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectBack(w http.ResponseWriter, r *http.Request, prompt string) {
|
||||||
|
values := make(url.Values)
|
||||||
|
values.Set("prompt", url.QueryEscape(prompt))
|
||||||
|
|
||||||
|
url := url.URL{
|
||||||
|
Path: "/device",
|
||||||
|
RawQuery: values.Encode(),
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, url.String(), http.StatusSeeOther)
|
||||||
|
}
|
||||||
|
|
||||||
|
const userCodeCookieName = "user_code"
|
||||||
|
|
||||||
|
type userCodeCookie struct {
|
||||||
|
UserCode string
|
||||||
|
UserName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *deviceLogin) loginHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
redirectBack(w, r, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userCode := r.PostForm.Get("user_code")
|
||||||
|
if userCode == "" {
|
||||||
|
redirectBack(w, r, "missing user_code in request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
username := r.PostForm.Get("username")
|
||||||
|
if username == "" {
|
||||||
|
redirectBack(w, r, "missing username in request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
password := r.PostForm.Get("password")
|
||||||
|
if password == "" {
|
||||||
|
redirectBack(w, r, "missing password in request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.storage.CheckUsernamePasswordSimple(username, password); err != nil {
|
||||||
|
redirectBack(w, r, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state, err := d.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode)
|
||||||
|
if err != nil {
|
||||||
|
redirectBack(w, r, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded, err := d.cookie.Encode(userCodeCookieName, userCodeCookie{userCode, username})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: userCodeCookieName,
|
||||||
|
Value: encoded,
|
||||||
|
Path: "/",
|
||||||
|
}
|
||||||
|
http.SetCookie(w, cookie)
|
||||||
|
renderConfirmPage(w, username, state.ClientID, state.Scopes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *deviceLogin) confirmHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cookie, err := r.Cookie(userCodeCookieName)
|
||||||
|
if err != nil {
|
||||||
|
redirectBack(w, r, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data := new(userCodeCookie)
|
||||||
|
if err = d.cookie.Decode(userCodeCookieName, cookie.Value, &data); err != nil {
|
||||||
|
redirectBack(w, r, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = r.ParseForm(); err != nil {
|
||||||
|
redirectBack(w, r, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
action := r.Form.Get("action")
|
||||||
|
switch action {
|
||||||
|
case "allowed":
|
||||||
|
err = d.storage.CompleteDeviceAuthorization(r.Context(), data.UserCode, data.UserName)
|
||||||
|
case "denied":
|
||||||
|
err = d.storage.DenyDeviceAuthorization(r.Context(), data.UserCode)
|
||||||
|
default:
|
||||||
|
err = errors.New("action must be one of \"allow\" or \"deny\"")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
redirectBack(w, r, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "Device authorization %s. You can now return to the device", action)
|
||||||
|
}
|
|
@ -3,45 +3,11 @@ package exampleop
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
queryAuthRequestID = "authRequestID"
|
|
||||||
)
|
|
||||||
|
|
||||||
var loginTmpl, _ = template.New("login").Parse(`
|
|
||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<title>Login</title>
|
|
||||||
</head>
|
|
||||||
<body style="display: flex; align-items: center; justify-content: center; height: 100vh;">
|
|
||||||
<form method="POST" action="/login/username" style="height: 200px; width: 200px;">
|
|
||||||
|
|
||||||
<input type="hidden" name="id" value="{{.ID}}">
|
|
||||||
|
|
||||||
<div>
|
|
||||||
<label for="username">Username:</label>
|
|
||||||
<input id="username" name="username" style="width: 100%">
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div>
|
|
||||||
<label for="password">Password:</label>
|
|
||||||
<input id="password" name="password" style="width: 100%">
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<p style="color:red; min-height: 1rem;">{{.Error}}</p>
|
|
||||||
|
|
||||||
<button type="submit">Login</button>
|
|
||||||
</form>
|
|
||||||
</body>
|
|
||||||
</html>`)
|
|
||||||
|
|
||||||
type login struct {
|
type login struct {
|
||||||
authenticate authenticate
|
authenticate authenticate
|
||||||
router *mux.Router
|
router *mux.Router
|
||||||
|
@ -74,23 +40,19 @@ func (l *login) loginHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// the oidc package will pass the id of the auth request as query parameter
|
// the oidc package will pass the id of the auth request as query parameter
|
||||||
// we will use this id through the login process and therefore pass it to the login page
|
// we will use this id through the login process and therefore pass it to the login page
|
||||||
renderLogin(w, r.FormValue(queryAuthRequestID), nil)
|
renderLogin(w, r.FormValue(queryAuthRequestID), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func renderLogin(w http.ResponseWriter, id string, err error) {
|
func renderLogin(w http.ResponseWriter, id string, err error) {
|
||||||
var errMsg string
|
|
||||||
if err != nil {
|
|
||||||
errMsg = err.Error()
|
|
||||||
}
|
|
||||||
data := &struct {
|
data := &struct {
|
||||||
ID string
|
ID string
|
||||||
Error string
|
Error string
|
||||||
}{
|
}{
|
||||||
ID: id,
|
ID: id,
|
||||||
Error: errMsg,
|
Error: errMsg(err),
|
||||||
}
|
}
|
||||||
err = loginTmpl.Execute(w, data)
|
err = templates.ExecuteTemplate(w, "login", data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
|
@ -27,7 +28,8 @@ func init() {
|
||||||
|
|
||||||
type Storage interface {
|
type Storage interface {
|
||||||
op.Storage
|
op.Storage
|
||||||
CheckUsernamePassword(username, password, id string) error
|
authenticate
|
||||||
|
deviceAuthenticate
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
|
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
|
||||||
|
@ -62,6 +64,9 @@ func SetupServer(ctx context.Context, issuer string, storage Storage) *mux.Route
|
||||||
// so we will direct all calls to /login to the login UI
|
// so we will direct all calls to /login to the login UI
|
||||||
router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router))
|
router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router))
|
||||||
|
|
||||||
|
router.PathPrefix("/device").Subrouter()
|
||||||
|
registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter())
|
||||||
|
|
||||||
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
||||||
// is served on the correct path
|
// is served on the correct path
|
||||||
//
|
//
|
||||||
|
@ -99,6 +104,13 @@ func newOP(ctx context.Context, storage op.Storage, issuer string, key [32]byte)
|
||||||
|
|
||||||
// this example has only static texts (in English), so we'll set the here accordingly
|
// this example has only static texts (in English), so we'll set the here accordingly
|
||||||
SupportedUILocales: []language.Tag{language.English},
|
SupportedUILocales: []language.Tag{language.English},
|
||||||
|
|
||||||
|
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||||
|
Lifetime: 5 * time.Minute,
|
||||||
|
PollInterval: 5 * time.Second,
|
||||||
|
UserFormURL: issuer + "device",
|
||||||
|
UserCode: op.UserCodeBase20,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
handler, err := op.NewOpenIDProvider(ctx, issuer, config, storage,
|
handler, err := op.NewOpenIDProvider(ctx, issuer, config, storage,
|
||||||
//we must explicitly allow the use of the http issuer
|
//we must explicitly allow the use of the http issuer
|
||||||
|
|
26
example/server/exampleop/templates.go
Normal file
26
example/server/exampleop/templates.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package exampleop
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"html/template"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed templates
|
||||||
|
templateFS embed.FS
|
||||||
|
templates = template.Must(template.ParseFS(templateFS, "templates/*.html"))
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
queryAuthRequestID = "authRequestID"
|
||||||
|
)
|
||||||
|
|
||||||
|
func errMsg(err error) string {
|
||||||
|
if err == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
logrus.Error(err)
|
||||||
|
return err.Error()
|
||||||
|
}
|
25
example/server/exampleop/templates/confirm_device.html
Normal file
25
example/server/exampleop/templates/confirm_device.html
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
{{ define "confirm_device" -}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>Confirm device authorization</title>
|
||||||
|
<style>
|
||||||
|
.green{
|
||||||
|
background-color: green
|
||||||
|
}
|
||||||
|
.red{
|
||||||
|
background-color: red
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Welcome back {{.Username}}!</h1>
|
||||||
|
<p>
|
||||||
|
You are about to grant device {{.ClientID}} access to the following scopes: {{.Scopes}}.
|
||||||
|
</p>
|
||||||
|
<button onclick="location.href='./confirm?action=allowed'" type="button" class="green">Allow</button>
|
||||||
|
<button onclick="location.href='./confirm?action=denied'" type="button" class="red">Deny</button>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
{{- end }}
|
29
example/server/exampleop/templates/device_login.html
Normal file
29
example/server/exampleop/templates/device_login.html
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
{{ define "device_login" -}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>Login</title>
|
||||||
|
</head>
|
||||||
|
<body style="display: flex; align-items: center; justify-content: center; height: 100vh;">
|
||||||
|
<form method="POST" action="/device/login" style="height: 200px; width: 200px;">
|
||||||
|
|
||||||
|
<input type="hidden" name="user_code" value="{{.UserCode}}">
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label for="username">Username:</label>
|
||||||
|
<input id="username" name="username" style="width: 100%">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label for="password">Password:</label>
|
||||||
|
<input id="password" name="password" style="width: 100%">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p style="color:red; min-height: 1rem;">{{.Error}}</p>
|
||||||
|
|
||||||
|
<button type="submit">Login</button>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
{{- end }}
|
29
example/server/exampleop/templates/login.html
Normal file
29
example/server/exampleop/templates/login.html
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
{{ define "login" -}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>Login</title>
|
||||||
|
</head>
|
||||||
|
<body style="display: flex; align-items: center; justify-content: center; height: 100vh;">
|
||||||
|
<form method="POST" action="/login/username" style="height: 200px; width: 200px;">
|
||||||
|
|
||||||
|
<input type="hidden" name="id" value="{{.ID}}">
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label for="username">Username:</label>
|
||||||
|
<input id="username" name="username" style="width: 100%">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label for="password">Password:</label>
|
||||||
|
<input id="password" name="password" style="width: 100%">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p style="color:red; min-height: 1rem;">{{.Error}}</p>
|
||||||
|
|
||||||
|
<button type="submit">Login</button>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
{{- end }}
|
21
example/server/exampleop/templates/usercode.html
Normal file
21
example/server/exampleop/templates/usercode.html
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
{{ define "usercode" -}}
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>Device authorization</title>
|
||||||
|
</head>
|
||||||
|
<body style="display: flex; align-items: center; justify-content: center; height: 100vh;">
|
||||||
|
<form method="POST" style="height: 200px; width: 200px;">
|
||||||
|
<h1>Device authorization</h1>
|
||||||
|
<div>
|
||||||
|
<label for="user_code">Code:</label>
|
||||||
|
<input id="user_code" name="user_code" style="width: 100%">
|
||||||
|
</div>
|
||||||
|
<p style="color:red; min-height: 1rem;">{{.Error}}</p>
|
||||||
|
|
||||||
|
<button type="submit">Login</button>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
{{- end }}
|
|
@ -44,6 +44,8 @@ type Storage struct {
|
||||||
services map[string]Service
|
services map[string]Service
|
||||||
refreshTokens map[string]*RefreshToken
|
refreshTokens map[string]*RefreshToken
|
||||||
signingKey signingKey
|
signingKey signingKey
|
||||||
|
deviceCodes map[string]deviceAuthorizationEntry
|
||||||
|
userCodes map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
type signingKey struct {
|
type signingKey struct {
|
||||||
|
@ -105,6 +107,8 @@ func NewStorage(userStore UserStore) *Storage {
|
||||||
algorithm: jose.RS256,
|
algorithm: jose.RS256,
|
||||||
key: key,
|
key: key,
|
||||||
},
|
},
|
||||||
|
deviceCodes: make(map[string]deviceAuthorizationEntry),
|
||||||
|
userCodes: make(map[string]string),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,6 +139,17 @@ func (s *Storage) CheckUsernamePassword(username, password, id string) error {
|
||||||
return fmt.Errorf("username or password wrong")
|
return fmt.Errorf("username or password wrong")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Storage) CheckUsernamePasswordSimple(username, password string) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
user := s.userStore.GetUserByUsername(username)
|
||||||
|
if user != nil && user.Password == password {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("username or password wrong")
|
||||||
|
}
|
||||||
|
|
||||||
// CreateAuthRequest implements the op.Storage interface
|
// CreateAuthRequest implements the op.Storage interface
|
||||||
// it will be called after parsing and validation of the authentication request
|
// it will be called after parsing and validation of the authentication request
|
||||||
func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
|
func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
|
||||||
|
@ -735,3 +750,85 @@ func appendClaim(claims map[string]interface{}, claim string, value interface{})
|
||||||
claims[claim] = value
|
claims[claim] = value
|
||||||
return claims
|
return claims
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type deviceAuthorizationEntry struct {
|
||||||
|
deviceCode string
|
||||||
|
userCode string
|
||||||
|
state *op.DeviceAuthorizationState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Storage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
if _, ok := s.clients[clientID]; !ok {
|
||||||
|
return errors.New("client not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := s.userCodes[userCode]; ok {
|
||||||
|
return op.ErrDuplicateUserCode
|
||||||
|
}
|
||||||
|
|
||||||
|
s.deviceCodes[deviceCode] = deviceAuthorizationEntry{
|
||||||
|
deviceCode: deviceCode,
|
||||||
|
userCode: userCode,
|
||||||
|
state: &op.DeviceAuthorizationState{
|
||||||
|
ClientID: clientID,
|
||||||
|
Scopes: scopes,
|
||||||
|
Expires: expires,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s.userCodes[userCode] = deviceCode
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Storage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*op.DeviceAuthorizationState, error) {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
entry, ok := s.deviceCodes[deviceCode]
|
||||||
|
if !ok || entry.state.ClientID != clientID {
|
||||||
|
return nil, errors.New("device code not found for client") // is there a standard not found error in the framework?
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Storage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
entry, ok := s.deviceCodes[s.userCodes[userCode]]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("user code not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Storage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
entry, ok := s.deviceCodes[s.userCodes[userCode]]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("user code not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.state.Subject = subject
|
||||||
|
entry.state.Done = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Storage) DenyDeviceAuthorization(ctx context.Context, userCode string) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
s.deviceCodes[s.userCodes[userCode]].state.Denied = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -186,3 +188,94 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti
|
||||||
IssuedAt: oidc.Time(iat),
|
IssuedAt: oidc.Time(iat),
|
||||||
}, signer)
|
}, signer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DeviceAuthorizationCaller interface {
|
||||||
|
GetDeviceAuthorizationEndpoint() string
|
||||||
|
HttpClient() *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
|
||||||
|
req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if request.ClientSecret != "" {
|
||||||
|
req.SetBasicAuth(request.ClientID, request.ClientSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := new(oidc.DeviceAuthorizationResponse)
|
||||||
|
if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeviceAccessTokenRequest struct {
|
||||||
|
*oidc.ClientCredentialsRequest
|
||||||
|
oidc.DeviceAccessTokenRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
|
||||||
|
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if request.ClientSecret != "" {
|
||||||
|
req.SetBasicAuth(request.ClientID, request.ClientSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpResp, err := caller.HttpClient().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer httpResp.Body.Close()
|
||||||
|
|
||||||
|
resp := new(struct {
|
||||||
|
*oidc.AccessTokenResponse
|
||||||
|
*oidc.Error
|
||||||
|
})
|
||||||
|
if err = json.NewDecoder(httpResp.Body).Decode(resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if httpResp.StatusCode == http.StatusOK {
|
||||||
|
return resp.AccessTokenResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, resp.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
|
||||||
|
for {
|
||||||
|
timer := time.After(interval)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-timer:
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, interval)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := CallDeviceAccessTokenEndpoint(ctx, request, caller)
|
||||||
|
if err == nil {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
interval += 5 * time.Second
|
||||||
|
}
|
||||||
|
var target *oidc.Error
|
||||||
|
if !errors.As(err, &target) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch target.ErrorType {
|
||||||
|
case oidc.AuthorizationPending:
|
||||||
|
continue
|
||||||
|
case oidc.SlowDown:
|
||||||
|
interval += 5 * time.Second
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
62
pkg/client/rp/device.go
Normal file
62
pkg/client/rp/device.go
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
package rp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/client"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
|
||||||
|
confg := rp.OAuthConfig()
|
||||||
|
req := &oidc.ClientCredentialsRequest{
|
||||||
|
GrantType: oidc.GrantTypeDeviceCode,
|
||||||
|
Scope: scopes,
|
||||||
|
ClientID: confg.ClientID,
|
||||||
|
ClientSecret: confg.ClientSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
if signer := rp.Signer(); signer != nil {
|
||||||
|
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to build assertion: %w", err)
|
||||||
|
}
|
||||||
|
req.ClientAssertion = assertion
|
||||||
|
req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAuthorization starts a new Device Authorization flow as defined
|
||||||
|
// in RFC 8628, section 3.1 and 3.2:
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
|
||||||
|
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
|
||||||
|
req, err := newDeviceClientCredentialsRequest(scopes, rp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return client.CallDeviceAuthorizationEndpoint(req, rp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
|
||||||
|
// by means of polling as defined in RFC, section 3.3 and 3.4:
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
|
||||||
|
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
|
||||||
|
req := &client.DeviceAccessTokenRequest{
|
||||||
|
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
|
||||||
|
GrantType: oidc.GrantTypeDeviceCode,
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp})
|
||||||
|
}
|
|
@ -59,6 +59,10 @@ type RelyingParty interface {
|
||||||
// UserinfoEndpoint returns the userinfo
|
// UserinfoEndpoint returns the userinfo
|
||||||
UserinfoEndpoint() string
|
UserinfoEndpoint() string
|
||||||
|
|
||||||
|
// GetDeviceAuthorizationEndpoint returns the enpoint which can
|
||||||
|
// be used to start a DeviceAuthorization flow.
|
||||||
|
GetDeviceAuthorizationEndpoint() string
|
||||||
|
|
||||||
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
|
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
|
||||||
IDTokenVerifier() IDTokenVerifier
|
IDTokenVerifier() IDTokenVerifier
|
||||||
// ErrorHandler returns the handler used for callback errors
|
// ErrorHandler returns the handler used for callback errors
|
||||||
|
@ -121,6 +125,10 @@ func (rp *relyingParty) UserinfoEndpoint() string {
|
||||||
return rp.endpoints.UserinfoURL
|
return rp.endpoints.UserinfoURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rp *relyingParty) GetDeviceAuthorizationEndpoint() string {
|
||||||
|
return rp.endpoints.DeviceAuthorizationURL
|
||||||
|
}
|
||||||
|
|
||||||
func (rp *relyingParty) GetEndSessionEndpoint() string {
|
func (rp *relyingParty) GetEndSessionEndpoint() string {
|
||||||
return rp.endpoints.EndSessionURL
|
return rp.endpoints.EndSessionURL
|
||||||
}
|
}
|
||||||
|
@ -495,11 +503,12 @@ type OptionFunc func(RelyingParty)
|
||||||
|
|
||||||
type Endpoints struct {
|
type Endpoints struct {
|
||||||
oauth2.Endpoint
|
oauth2.Endpoint
|
||||||
IntrospectURL string
|
IntrospectURL string
|
||||||
UserinfoURL string
|
UserinfoURL string
|
||||||
JKWsURL string
|
JKWsURL string
|
||||||
EndSessionURL string
|
EndSessionURL string
|
||||||
RevokeURL string
|
RevokeURL string
|
||||||
|
DeviceAuthorizationURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
||||||
|
@ -509,11 +518,12 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
||||||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||||
TokenURL: discoveryConfig.TokenEndpoint,
|
TokenURL: discoveryConfig.TokenEndpoint,
|
||||||
},
|
},
|
||||||
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
||||||
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
||||||
JKWsURL: discoveryConfig.JwksURI,
|
JKWsURL: discoveryConfig.JwksURI,
|
||||||
EndSessionURL: discoveryConfig.EndSessionEndpoint,
|
EndSessionURL: discoveryConfig.EndSessionEndpoint,
|
||||||
RevokeURL: discoveryConfig.RevocationEndpoint,
|
RevokeURL: discoveryConfig.RevocationEndpoint,
|
||||||
|
DeviceAuthorizationURL: discoveryConfig.DeviceAuthorizationEndpoint,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
29
pkg/oidc/device_authorization.go
Normal file
29
pkg/oidc/device_authorization.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package oidc
|
||||||
|
|
||||||
|
// DeviceAuthorizationRequest implements
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1,
|
||||||
|
// 3.1 Device Authorization Request.
|
||||||
|
type DeviceAuthorizationRequest struct {
|
||||||
|
Scopes SpaceDelimitedArray `schema:"scope"`
|
||||||
|
ClientID string `schema:"client_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAuthorizationResponse implements
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc8628#section-3.2
|
||||||
|
// 3.2. Device Authorization Response.
|
||||||
|
type DeviceAuthorizationResponse struct {
|
||||||
|
DeviceCode string `json:"device_code"`
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
VerificationURI string `json:"verification_uri"`
|
||||||
|
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
Interval int `json:"interval,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAccessTokenRequest implements
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4,
|
||||||
|
// Device Access Token Request.
|
||||||
|
type DeviceAccessTokenRequest struct {
|
||||||
|
GrantType GrantType `json:"grant_type" schema:"grant_type"`
|
||||||
|
DeviceCode string `json:"device_code" schema:"device_code"`
|
||||||
|
}
|
|
@ -30,6 +30,8 @@ type DiscoveryConfiguration struct {
|
||||||
// EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP.
|
// EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP.
|
||||||
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
|
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
|
||||||
|
|
||||||
|
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"`
|
||||||
|
|
||||||
// CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client.
|
// CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client.
|
||||||
CheckSessionIframe string `json:"check_session_iframe,omitempty"`
|
CheckSessionIframe string `json:"check_session_iframe,omitempty"`
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,14 @@ const (
|
||||||
InteractionRequired errorType = "interaction_required"
|
InteractionRequired errorType = "interaction_required"
|
||||||
LoginRequired errorType = "login_required"
|
LoginRequired errorType = "login_required"
|
||||||
RequestNotSupported errorType = "request_not_supported"
|
RequestNotSupported errorType = "request_not_supported"
|
||||||
|
|
||||||
|
// Additional error codes as defined in
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc8628#section-3.5
|
||||||
|
// Device Access Token Response
|
||||||
|
AuthorizationPending errorType = "authorization_pending"
|
||||||
|
SlowDown errorType = "slow_down"
|
||||||
|
AccessDenied errorType = "access_denied"
|
||||||
|
ExpiredToken errorType = "expired_token"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -77,6 +85,32 @@ var (
|
||||||
ErrorType: RequestNotSupported,
|
ErrorType: RequestNotSupported,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device Access Token errors:
|
||||||
|
ErrAuthorizationPending = func() *Error {
|
||||||
|
return &Error{
|
||||||
|
ErrorType: AuthorizationPending,
|
||||||
|
Description: "The client SHOULD repeat the access token request to the token endpoint, after interval from device authorization response.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ErrSlowDown = func() *Error {
|
||||||
|
return &Error{
|
||||||
|
ErrorType: SlowDown,
|
||||||
|
Description: "Polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ErrAccessDenied = func() *Error {
|
||||||
|
return &Error{
|
||||||
|
ErrorType: AccessDenied,
|
||||||
|
Description: "The authorization request was denied.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ErrExpiredDeviceCode = func() *Error {
|
||||||
|
return &Error{
|
||||||
|
ErrorType: ExpiredToken,
|
||||||
|
Description: "The \"device_code\" has expired.",
|
||||||
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
|
|
|
@ -27,6 +27,9 @@ const (
|
||||||
// GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code
|
// GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code
|
||||||
GrantTypeImplicit GrantType = "implicit"
|
GrantTypeImplicit GrantType = "implicit"
|
||||||
|
|
||||||
|
// GrantTypeDeviceCode
|
||||||
|
GrantTypeDeviceCode GrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
|
|
||||||
// ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`
|
// ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`
|
||||||
// used for the OAuth JWT Profile Client Authentication
|
// used for the OAuth JWT Profile Client Authentication
|
||||||
ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
|
ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
|
||||||
|
@ -35,7 +38,7 @@ const (
|
||||||
var AllGrantTypes = []GrantType{
|
var AllGrantTypes = []GrantType{
|
||||||
GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials,
|
GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials,
|
||||||
GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit,
|
GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit,
|
||||||
ClientAssertionTypeJWTAssertion,
|
GrantTypeDeviceCode, ClientAssertionTypeJWTAssertion,
|
||||||
}
|
}
|
||||||
|
|
||||||
type GrantType string
|
type GrantType string
|
||||||
|
|
|
@ -1,8 +1,13 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -57,3 +62,95 @@ func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseT
|
||||||
func IsConfidentialType(c Client) bool {
|
func IsConfidentialType(c Client) bool {
|
||||||
return c.ApplicationType() == ApplicationTypeWeb
|
return c.ApplicationType() == ApplicationTypeWeb
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidAuthHeader = errors.New("invalid basic auth header")
|
||||||
|
ErrNoClientCredentials = errors.New("no client credentials provided")
|
||||||
|
ErrMissingClientID = errors.New("client_id missing from request")
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClientJWTProfile interface {
|
||||||
|
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
|
||||||
|
if ca.ClientAssertion == "" {
|
||||||
|
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
profile, err := VerifyJWTAssertion(ctx, ca.ClientAssertion, verifier.JWTProfileVerifier(ctx))
|
||||||
|
if err != nil {
|
||||||
|
return "", oidc.ErrUnauthorizedClient().WithParent(err).WithDescription("JWT assertion failed")
|
||||||
|
}
|
||||||
|
return profile.Issuer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) {
|
||||||
|
clientID, clientSecret, ok := r.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
|
||||||
|
}
|
||||||
|
clientID, err = url.QueryUnescape(clientID)
|
||||||
|
if err != nil {
|
||||||
|
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
|
||||||
|
}
|
||||||
|
clientSecret, err = url.QueryUnescape(clientSecret)
|
||||||
|
if err != nil {
|
||||||
|
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
|
||||||
|
}
|
||||||
|
if err := storage.AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
|
||||||
|
return "", oidc.ErrUnauthorizedClient().WithParent(err)
|
||||||
|
}
|
||||||
|
return clientID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientProvider interface {
|
||||||
|
Decoder() httphelper.Decoder
|
||||||
|
Storage() Storage
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientData struct {
|
||||||
|
ClientID string `schema:"client_id"`
|
||||||
|
oidc.ClientAssertionParams
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientIDFromRequest parses the request form and tries to obtain the client ID
|
||||||
|
// and reports if it is authenticated, using a JWT or static client secrets over
|
||||||
|
// http basic auth.
|
||||||
|
//
|
||||||
|
// If the Provider implements IntrospectorJWTProfile and "client_assertion" is
|
||||||
|
// present in the form data, JWT assertion will be verified and the
|
||||||
|
// client ID is taken from there.
|
||||||
|
// If any of them is absent, basic auth is attempted.
|
||||||
|
// In absence of basic auth data, the unauthenticated client id from the form
|
||||||
|
// data is returned.
|
||||||
|
//
|
||||||
|
// If no client id can be obtained by any method, oidc.ErrInvalidClient
|
||||||
|
// is returned with ErrMissingClientID wrapped in it.
|
||||||
|
func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, authenticated bool, err error) {
|
||||||
|
err = r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := new(clientData)
|
||||||
|
if err = p.Decoder().Decode(data, r.PostForm); err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
JWTProfile, ok := p.(ClientJWTProfile)
|
||||||
|
if ok {
|
||||||
|
clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
|
||||||
|
}
|
||||||
|
if !ok || errors.Is(err, ErrNoClientCredentials) {
|
||||||
|
clientID, err = ClientBasicAuth(r, p.Storage())
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return clientID, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.ClientID == "" {
|
||||||
|
return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
|
||||||
|
}
|
||||||
|
return data.ClientID, false, nil
|
||||||
|
}
|
||||||
|
|
253
pkg/op/client_test.go
Normal file
253
pkg/op/client_test.go
Normal file
|
@ -0,0 +1,253 @@
|
||||||
|
package op_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/gorilla/schema"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/op"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/op/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testClientJWTProfile struct{}
|
||||||
|
|
||||||
|
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
|
||||||
|
|
||||||
|
func TestClientJWTAuth(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
ca oidc.ClientAssertionParams
|
||||||
|
verifier op.ClientJWTProfile
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantClientID string
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty assertion",
|
||||||
|
args: args{
|
||||||
|
context.Background(),
|
||||||
|
oidc.ClientAssertionParams{},
|
||||||
|
testClientJWTProfile{},
|
||||||
|
},
|
||||||
|
wantErr: op.ErrNoClientCredentials,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "verification error",
|
||||||
|
args: args{
|
||||||
|
context.Background(),
|
||||||
|
oidc.ClientAssertionParams{
|
||||||
|
ClientAssertion: "foo",
|
||||||
|
},
|
||||||
|
testClientJWTProfile{},
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrParse,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotClientID, err := op.ClientJWTAuth(tt.args.ctx, tt.args.ca, tt.args.verifier)
|
||||||
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
|
assert.Equal(t, tt.wantClientID, gotClientID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientBasicAuth(t *testing.T) {
|
||||||
|
errWrong := errors.New("wrong secret")
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args *args
|
||||||
|
storage op.Storage
|
||||||
|
wantClientID string
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no args",
|
||||||
|
wantErr: op.ErrNoClientCredentials,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "username unescape err",
|
||||||
|
args: &args{
|
||||||
|
username: "%",
|
||||||
|
password: "bar",
|
||||||
|
},
|
||||||
|
wantErr: op.ErrInvalidAuthHeader,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "password unescape err",
|
||||||
|
args: &args{
|
||||||
|
username: "foo",
|
||||||
|
password: "%",
|
||||||
|
},
|
||||||
|
wantErr: op.ErrInvalidAuthHeader,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auth error",
|
||||||
|
args: &args{
|
||||||
|
username: "foo",
|
||||||
|
password: "wrong",
|
||||||
|
},
|
||||||
|
storage: func() op.Storage {
|
||||||
|
s := mock.NewMockStorage(gomock.NewController(t))
|
||||||
|
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong)
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
wantErr: errWrong,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auth error",
|
||||||
|
args: &args{
|
||||||
|
username: "foo",
|
||||||
|
password: "bar",
|
||||||
|
},
|
||||||
|
storage: func() op.Storage {
|
||||||
|
s := mock.NewMockStorage(gomock.NewController(t))
|
||||||
|
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
wantClientID: "foo",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
if tt.args != nil {
|
||||||
|
r.SetBasicAuth(tt.args.username, tt.args.password)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotClientID, err := op.ClientBasicAuth(r, tt.storage)
|
||||||
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
|
assert.Equal(t, tt.wantClientID, gotClientID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type errReader struct{}
|
||||||
|
|
||||||
|
func (errReader) Read([]byte) (int, error) {
|
||||||
|
return 0, io.ErrNoProgress
|
||||||
|
}
|
||||||
|
|
||||||
|
type testClientProvider struct {
|
||||||
|
storage op.Storage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (testClientProvider) Decoder() httphelper.Decoder {
|
||||||
|
return schema.NewDecoder()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p testClientProvider) Storage() op.Storage {
|
||||||
|
return p.storage
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientIDFromRequest(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
body io.Reader
|
||||||
|
p op.ClientProvider
|
||||||
|
}
|
||||||
|
type basicAuth struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
basicAuth *basicAuth
|
||||||
|
wantClientID string
|
||||||
|
wantAuthenticated bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "parse error",
|
||||||
|
args: args{
|
||||||
|
body: errReader{},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthenticated",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{
|
||||||
|
"client_id": []string{"foo"},
|
||||||
|
}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: mock.NewStorage(t),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantClientID: "foo",
|
||||||
|
wantAuthenticated: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "authenticated",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: func() op.Storage {
|
||||||
|
s := mock.NewMockStorage(gomock.NewController(t))
|
||||||
|
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
basicAuth: &basicAuth{
|
||||||
|
username: "foo",
|
||||||
|
password: "bar",
|
||||||
|
},
|
||||||
|
wantClientID: "foo",
|
||||||
|
wantAuthenticated: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing client id",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: mock.NewStorage(t),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
if tt.basicAuth != nil {
|
||||||
|
r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotClientID, gotAuthenticated, err := op.ClientIDFromRequest(r, tt.args.p)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.wantClientID, gotClientID)
|
||||||
|
assert.Equal(t, tt.wantAuthenticated, gotAuthenticated)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -27,6 +27,7 @@ type Configuration interface {
|
||||||
RevocationEndpoint() Endpoint
|
RevocationEndpoint() Endpoint
|
||||||
EndSessionEndpoint() Endpoint
|
EndSessionEndpoint() Endpoint
|
||||||
KeysEndpoint() Endpoint
|
KeysEndpoint() Endpoint
|
||||||
|
DeviceAuthorizationEndpoint() Endpoint
|
||||||
|
|
||||||
AuthMethodPostSupported() bool
|
AuthMethodPostSupported() bool
|
||||||
CodeMethodS256Supported() bool
|
CodeMethodS256Supported() bool
|
||||||
|
@ -36,6 +37,7 @@ type Configuration interface {
|
||||||
GrantTypeTokenExchangeSupported() bool
|
GrantTypeTokenExchangeSupported() bool
|
||||||
GrantTypeJWTAuthorizationSupported() bool
|
GrantTypeJWTAuthorizationSupported() bool
|
||||||
GrantTypeClientCredentialsSupported() bool
|
GrantTypeClientCredentialsSupported() bool
|
||||||
|
GrantTypeDeviceCodeSupported() bool
|
||||||
IntrospectionAuthMethodPrivateKeyJWTSupported() bool
|
IntrospectionAuthMethodPrivateKeyJWTSupported() bool
|
||||||
IntrospectionEndpointSigningAlgorithmsSupported() []string
|
IntrospectionEndpointSigningAlgorithmsSupported() []string
|
||||||
RevocationAuthMethodPrivateKeyJWTSupported() bool
|
RevocationAuthMethodPrivateKeyJWTSupported() bool
|
||||||
|
@ -44,6 +46,7 @@ type Configuration interface {
|
||||||
RequestObjectSigningAlgorithmsSupported() []string
|
RequestObjectSigningAlgorithmsSupported() []string
|
||||||
|
|
||||||
SupportedUILocales() []language.Tag
|
SupportedUILocales() []language.Tag
|
||||||
|
DeviceAuthorization() DeviceAuthorizationConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type IssuerFromRequest func(r *http.Request) string
|
type IssuerFromRequest func(r *http.Request) string
|
||||||
|
|
265
pkg/op/device.go
Normal file
265
pkg/op/device.go
Normal file
|
@ -0,0 +1,265 @@
|
||||||
|
package op
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DeviceAuthorizationConfig struct {
|
||||||
|
Lifetime time.Duration
|
||||||
|
PollInterval time.Duration
|
||||||
|
UserFormURL string // the URL where the user must go to authorize the device
|
||||||
|
UserCode UserCodeConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserCodeConfig struct {
|
||||||
|
CharSet string
|
||||||
|
CharAmount int
|
||||||
|
DashInterval int
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ"
|
||||||
|
CharSetDigits = "0123456789"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
UserCodeBase20 = UserCodeConfig{
|
||||||
|
CharSet: CharSetBase20,
|
||||||
|
CharAmount: 8,
|
||||||
|
DashInterval: 4,
|
||||||
|
}
|
||||||
|
UserCodeDigits = UserCodeConfig{
|
||||||
|
CharSet: CharSetDigits,
|
||||||
|
CharAmount: 9,
|
||||||
|
DashInterval: 3,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := DeviceAuthorization(w, r, o); err != nil {
|
||||||
|
RequestError(w, r, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
|
||||||
|
storage, err := assertDeviceStorage(o.Storage())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := ParseDeviceCodeRequest(r, o)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := o.DeviceAuthorization()
|
||||||
|
|
||||||
|
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
expires := time.Now().Add(config.Lifetime)
|
||||||
|
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &oidc.DeviceAuthorizationResponse{
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
UserCode: userCode,
|
||||||
|
VerificationURI: config.UserFormURL,
|
||||||
|
ExpiresIn: int(config.Lifetime / time.Second),
|
||||||
|
Interval: int(config.PollInterval / time.Second),
|
||||||
|
}
|
||||||
|
|
||||||
|
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode)
|
||||||
|
|
||||||
|
httphelper.MarshalJSON(w, response)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
|
||||||
|
clientID, _, err := ClientIDFromRequest(r, o)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req := new(oidc.DeviceAuthorizationRequest)
|
||||||
|
if err := o.Decoder().Decode(req, r.Form); err != nil {
|
||||||
|
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
|
||||||
|
}
|
||||||
|
req.ClientID = clientID
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 16 bytes gives 128 bit of entropy.
|
||||||
|
// results in a 22 character base64 encoded string.
|
||||||
|
const RecommendedDeviceCodeBytes = 16
|
||||||
|
|
||||||
|
func NewDeviceCode(nBytes int) (string, error) {
|
||||||
|
bytes := make([]byte, nBytes)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", fmt.Errorf("%w getting entropy for device code", err)
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
|
||||||
|
var buf strings.Builder
|
||||||
|
if dashInterval > 0 {
|
||||||
|
buf.Grow(charAmount + charAmount/dashInterval - 1)
|
||||||
|
} else {
|
||||||
|
buf.Grow(charAmount)
|
||||||
|
}
|
||||||
|
|
||||||
|
max := big.NewInt(int64(len(charSet)))
|
||||||
|
|
||||||
|
for i := 0; i < charAmount; i++ {
|
||||||
|
if dashInterval != 0 && i != 0 && i%dashInterval == 0 {
|
||||||
|
buf.WriteByte('-')
|
||||||
|
}
|
||||||
|
|
||||||
|
bi, err := rand.Int(rand.Reader, max)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("%w getting entropy for user code", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteRune(charSet[int(bi.Int64())])
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type deviceAccessTokenRequest struct {
|
||||||
|
subject string
|
||||||
|
audience []string
|
||||||
|
scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *deviceAccessTokenRequest) GetSubject() string {
|
||||||
|
return r.subject
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *deviceAccessTokenRequest) GetAudience() []string {
|
||||||
|
return r.audience
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *deviceAccessTokenRequest) GetScopes() []string {
|
||||||
|
return r.scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
|
if err := deviceAccessToken(w, r, exchanger); err != nil {
|
||||||
|
RequestError(w, r, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error {
|
||||||
|
// use a limited context timeout shorter as the default
|
||||||
|
// poll interval of 5 seconds.
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if clientAuthenticated != IsConfidentialType(client) {
|
||||||
|
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
|
||||||
|
WithDescription("confidential client requires authentication")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenRequest := &deviceAccessTokenRequest{
|
||||||
|
subject: state.Subject,
|
||||||
|
audience: []string{clientID},
|
||||||
|
scopes: state.Scopes,
|
||||||
|
}
|
||||||
|
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
httphelper.MarshalJSON(w, resp)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
|
||||||
|
req := new(oidc.DeviceAccessTokenRequest)
|
||||||
|
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
||||||
|
storage, err := assertDeviceStorage(exchanger.Storage())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return nil, oidc.ErrSlowDown().WithParent(err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, oidc.ErrAccessDenied().WithParent(err)
|
||||||
|
}
|
||||||
|
if state.Denied {
|
||||||
|
return state, oidc.ErrAccessDenied()
|
||||||
|
}
|
||||||
|
if state.Done {
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
if time.Now().After(state.Expires) {
|
||||||
|
return state, oidc.ErrExpiredDeviceCode()
|
||||||
|
}
|
||||||
|
return state, oidc.ErrAuthorizationPending()
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) {
|
||||||
|
accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, client, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &oidc.AccessTokenResponse{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
TokenType: oidc.BearerToken,
|
||||||
|
ExpiresIn: uint64(validity.Seconds()),
|
||||||
|
}, nil
|
||||||
|
}
|
453
pkg/op/device_test.go
Normal file
453
pkg/op/device_test.go
Normal file
|
@ -0,0 +1,453 @@
|
||||||
|
package op_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"io"
|
||||||
|
mr "math/rand"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/zitadel/oidc/v2/example/server/storage"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/op"
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testProvider op.OpenIDProvider
|
||||||
|
|
||||||
|
const (
|
||||||
|
testIssuer = "https://localhost:9998/"
|
||||||
|
pathLoggedOut = "/logged-out"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
config := &op.Config{
|
||||||
|
CryptoKey: sha256.Sum256([]byte("test")),
|
||||||
|
DefaultLogoutRedirectURI: pathLoggedOut,
|
||||||
|
CodeMethodS256: true,
|
||||||
|
AuthMethodPost: true,
|
||||||
|
AuthMethodPrivateKeyJWT: true,
|
||||||
|
GrantTypeRefreshToken: true,
|
||||||
|
RequestObjectSupported: true,
|
||||||
|
SupportedUILocales: []language.Tag{language.English},
|
||||||
|
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||||
|
Lifetime: 5 * time.Minute,
|
||||||
|
PollInterval: 5 * time.Second,
|
||||||
|
UserFormURL: testIssuer + "device",
|
||||||
|
UserCode: op.UserCodeBase20,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
storage.RegisterClients(
|
||||||
|
storage.NativeClient("native"),
|
||||||
|
storage.WebClient("web", "secret"),
|
||||||
|
storage.WebClient("api", "secret"),
|
||||||
|
)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
testProvider, err = op.NewOpenIDProvider(context.TODO(), testIssuer, config,
|
||||||
|
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_deviceAuthorizationHandler(t *testing.T) {
|
||||||
|
req := &oidc.DeviceAuthorizationRequest{
|
||||||
|
Scopes: []string{"foo", "bar"},
|
||||||
|
ClientID: "web",
|
||||||
|
}
|
||||||
|
values := make(url.Values)
|
||||||
|
testProvider.Encoder().Encode(req, values)
|
||||||
|
body := strings.NewReader(values.Encode())
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/", body)
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
runWithRandReader(mr.New(mr.NewSource(1)), func() {
|
||||||
|
op.DeviceAuthorizationHandler(testProvider)(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
result := w.Result()
|
||||||
|
|
||||||
|
assert.Less(t, result.StatusCode, 300)
|
||||||
|
|
||||||
|
got, _ := io.ReadAll(result.Body)
|
||||||
|
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDeviceCodeRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req *oidc.DeviceAuthorizationRequest
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty request",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
/* decoding a SpaceDelimitedArray is broken
|
||||||
|
https://github.com/zitadel/oidc/issues/295
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
req: &oidc.DeviceAuthorizationRequest{
|
||||||
|
Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
|
||||||
|
ClientID: "web",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var body io.Reader
|
||||||
|
if tt.req != nil {
|
||||||
|
values := make(url.Values)
|
||||||
|
testProvider.Encoder().Encode(tt.req, values)
|
||||||
|
body = strings.NewReader(values.Encode())
|
||||||
|
}
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/", body)
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
got, err := op.ParseDeviceCodeRequest(r, testProvider)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.req, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runWithRandReader(r io.Reader, f func()) {
|
||||||
|
originalReader := rand.Reader
|
||||||
|
rand.Reader = r
|
||||||
|
defer func() {
|
||||||
|
rand.Reader = originalReader
|
||||||
|
}()
|
||||||
|
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewDeviceCode(t *testing.T) {
|
||||||
|
t.Run("reader error", func(t *testing.T) {
|
||||||
|
runWithRandReader(errReader{}, func() {
|
||||||
|
_, err := op.NewDeviceCode(16)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different lengths, rand reader", func(t *testing.T) {
|
||||||
|
for i := 1; i <= 32; i++ {
|
||||||
|
got, err := op.NewDeviceCode(i)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewUserCode(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
charset []rune
|
||||||
|
charAmount int
|
||||||
|
dashInterval int
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
reader io.Reader
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "reader error",
|
||||||
|
args: args{
|
||||||
|
charset: []rune(op.CharSetBase20),
|
||||||
|
charAmount: 8,
|
||||||
|
dashInterval: 4,
|
||||||
|
},
|
||||||
|
reader: errReader{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "base20",
|
||||||
|
args: args{
|
||||||
|
charset: []rune(op.CharSetBase20),
|
||||||
|
charAmount: 8,
|
||||||
|
dashInterval: 4,
|
||||||
|
},
|
||||||
|
reader: mr.New(mr.NewSource(1)),
|
||||||
|
want: "XKCD-HTTD",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "digits",
|
||||||
|
args: args{
|
||||||
|
charset: []rune(op.CharSetDigits),
|
||||||
|
charAmount: 9,
|
||||||
|
dashInterval: 3,
|
||||||
|
},
|
||||||
|
reader: mr.New(mr.NewSource(1)),
|
||||||
|
want: "271-256-225",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no dashes",
|
||||||
|
args: args{
|
||||||
|
charset: []rune(op.CharSetDigits),
|
||||||
|
charAmount: 9,
|
||||||
|
},
|
||||||
|
reader: mr.New(mr.NewSource(1)),
|
||||||
|
want: "271256225",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
runWithRandReader(tt.reader, func() {
|
||||||
|
got, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.ErrorIs(t, err, io.ErrNoProgress)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("crypto/rand", func(t *testing.T) {
|
||||||
|
const testN = 100000
|
||||||
|
|
||||||
|
for _, c := range []op.UserCodeConfig{op.UserCodeBase20, op.UserCodeDigits} {
|
||||||
|
t.Run(c.CharSet, func(t *testing.T) {
|
||||||
|
results := make(map[string]int)
|
||||||
|
|
||||||
|
for i := 0; i < testN; i++ {
|
||||||
|
code, err := op.NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval)
|
||||||
|
require.NoError(t, err)
|
||||||
|
results[code]++
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(results)
|
||||||
|
|
||||||
|
var duplicates int
|
||||||
|
for code, count := range results {
|
||||||
|
assert.Less(t, count, 3, code)
|
||||||
|
if count == 2 {
|
||||||
|
duplicates++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkNewUserCode(b *testing.B) {
|
||||||
|
type args struct {
|
||||||
|
charset []rune
|
||||||
|
charAmount int
|
||||||
|
dashInterval int
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
reader io.Reader
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "math rand, base20",
|
||||||
|
args: args{
|
||||||
|
charset: []rune(op.CharSetBase20),
|
||||||
|
charAmount: 8,
|
||||||
|
dashInterval: 4,
|
||||||
|
},
|
||||||
|
reader: mr.New(mr.NewSource(1)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "math rand, digits",
|
||||||
|
args: args{
|
||||||
|
charset: []rune(op.CharSetDigits),
|
||||||
|
charAmount: 9,
|
||||||
|
dashInterval: 3,
|
||||||
|
},
|
||||||
|
reader: mr.New(mr.NewSource(1)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "crypto rand, base20",
|
||||||
|
args: args{
|
||||||
|
charset: []rune(op.CharSetBase20),
|
||||||
|
charAmount: 8,
|
||||||
|
dashInterval: 4,
|
||||||
|
},
|
||||||
|
reader: rand.Reader,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "crypto rand, digits",
|
||||||
|
args: args{
|
||||||
|
charset: []rune(op.CharSetDigits),
|
||||||
|
charAmount: 9,
|
||||||
|
dashInterval: 3,
|
||||||
|
},
|
||||||
|
reader: rand.Reader,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
runWithRandReader(tt.reader, func() {
|
||||||
|
b.Run(tt.name, func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceAccessToken(t *testing.T) {
|
||||||
|
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
|
||||||
|
storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"})
|
||||||
|
storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim")
|
||||||
|
|
||||||
|
values := make(url.Values)
|
||||||
|
values.Set("client_id", "native")
|
||||||
|
values.Set("grant_type", string(oidc.GrantTypeDeviceCode))
|
||||||
|
values.Set("device_code", "qwerty")
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(values.Encode()))
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
op.DeviceAccessToken(w, r, testProvider)
|
||||||
|
|
||||||
|
result := w.Result()
|
||||||
|
got, _ := io.ReadAll(result.Body)
|
||||||
|
t.Log(string(got))
|
||||||
|
assert.Less(t, result.StatusCode, 300)
|
||||||
|
assert.NotEmpty(t, string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckDeviceAuthorizationState(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
|
||||||
|
storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"})
|
||||||
|
storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"})
|
||||||
|
storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"})
|
||||||
|
storage.StoreDeviceAuthorization(context.Background(), "native", "expired", "expired", now.Add(-time.Minute), []string{"foo"})
|
||||||
|
|
||||||
|
storage.DenyDeviceAuthorization(context.Background(), "denied")
|
||||||
|
storage.CompleteDeviceAuthorization(context.Background(), "completed", "tim")
|
||||||
|
|
||||||
|
exceededCtx, cancel := context.WithTimeout(context.Background(), -time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
clientID string
|
||||||
|
deviceCode string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *op.DeviceAuthorizationState
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "pending",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
clientID: "native",
|
||||||
|
deviceCode: "pending",
|
||||||
|
},
|
||||||
|
want: &op.DeviceAuthorizationState{
|
||||||
|
ClientID: "native",
|
||||||
|
Scopes: []string{"foo"},
|
||||||
|
Expires: now.Add(time.Minute),
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrAuthorizationPending(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "slow down",
|
||||||
|
args: args{
|
||||||
|
ctx: exceededCtx,
|
||||||
|
clientID: "native",
|
||||||
|
deviceCode: "ok",
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrSlowDown(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong client",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
clientID: "foo",
|
||||||
|
deviceCode: "ok",
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrAccessDenied(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denied",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
clientID: "native",
|
||||||
|
deviceCode: "denied",
|
||||||
|
},
|
||||||
|
want: &op.DeviceAuthorizationState{
|
||||||
|
ClientID: "native",
|
||||||
|
Scopes: []string{"foo"},
|
||||||
|
Expires: now.Add(time.Minute),
|
||||||
|
Denied: true,
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrAccessDenied(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completed",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
clientID: "native",
|
||||||
|
deviceCode: "completed",
|
||||||
|
},
|
||||||
|
want: &op.DeviceAuthorizationState{
|
||||||
|
ClientID: "native",
|
||||||
|
Scopes: []string{"foo"},
|
||||||
|
Expires: now.Add(time.Minute),
|
||||||
|
Subject: "tim",
|
||||||
|
Done: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
clientID: "native",
|
||||||
|
deviceCode: "expired",
|
||||||
|
},
|
||||||
|
want: &op.DeviceAuthorizationState{
|
||||||
|
ClientID: "native",
|
||||||
|
Scopes: []string{"foo"},
|
||||||
|
Expires: now.Add(-time.Minute),
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrExpiredDeviceCode(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := op.CheckDeviceAuthorizationState(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, testProvider)
|
||||||
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -44,6 +44,7 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov
|
||||||
RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer),
|
RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer),
|
||||||
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
|
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
|
||||||
JwksURI: config.KeysEndpoint().Absolute(issuer),
|
JwksURI: config.KeysEndpoint().Absolute(issuer),
|
||||||
|
DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
|
||||||
ScopesSupported: Scopes(config),
|
ScopesSupported: Scopes(config),
|
||||||
ResponseTypesSupported: ResponseTypes(config),
|
ResponseTypesSupported: ResponseTypes(config),
|
||||||
GrantTypesSupported: GrantTypes(config),
|
GrantTypesSupported: GrantTypes(config),
|
||||||
|
@ -92,6 +93,9 @@ func GrantTypes(c Configuration) []oidc.GrantType {
|
||||||
if c.GrantTypeJWTAuthorizationSupported() {
|
if c.GrantTypeJWTAuthorizationSupported() {
|
||||||
grantTypes = append(grantTypes, oidc.GrantTypeBearer)
|
grantTypes = append(grantTypes, oidc.GrantTypeBearer)
|
||||||
}
|
}
|
||||||
|
if c.GrantTypeDeviceCodeSupported() {
|
||||||
|
grantTypes = append(grantTypes, oidc.GrantTypeDeviceCode)
|
||||||
|
}
|
||||||
return grantTypes
|
return grantTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -131,6 +131,7 @@ func Test_GrantTypes(t *testing.T) {
|
||||||
c.EXPECT().GrantTypeTokenExchangeSupported().Return(false)
|
c.EXPECT().GrantTypeTokenExchangeSupported().Return(false)
|
||||||
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(false)
|
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(false)
|
||||||
c.EXPECT().GrantTypeClientCredentialsSupported().Return(false)
|
c.EXPECT().GrantTypeClientCredentialsSupported().Return(false)
|
||||||
|
c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
|
||||||
return c
|
return c
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
|
@ -148,6 +149,7 @@ func Test_GrantTypes(t *testing.T) {
|
||||||
c.EXPECT().GrantTypeTokenExchangeSupported().Return(true)
|
c.EXPECT().GrantTypeTokenExchangeSupported().Return(true)
|
||||||
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(true)
|
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(true)
|
||||||
c.EXPECT().GrantTypeClientCredentialsSupported().Return(true)
|
c.EXPECT().GrantTypeClientCredentialsSupported().Return(true)
|
||||||
|
c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
|
||||||
return c
|
return c
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
|
|
|
@ -92,6 +92,34 @@ func (mr *MockConfigurationMockRecorder) CodeMethodS256Supported() *gomock.Call
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeviceAuthorization mocks base method.
|
||||||
|
func (m *MockConfiguration) DeviceAuthorization() op.DeviceAuthorizationConfig {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DeviceAuthorization")
|
||||||
|
ret0, _ := ret[0].(op.DeviceAuthorizationConfig)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAuthorization indicates an expected call of DeviceAuthorization.
|
||||||
|
func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorization", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorization))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAuthorizationEndpoint mocks base method.
|
||||||
|
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
|
||||||
|
ret0, _ := ret[0].(op.Endpoint)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAuthorizationEndpoint indicates an expected call of DeviceAuthorizationEndpoint.
|
||||||
|
func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorizationEndpoint))
|
||||||
|
}
|
||||||
|
|
||||||
// EndSessionEndpoint mocks base method.
|
// EndSessionEndpoint mocks base method.
|
||||||
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
|
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -120,6 +148,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeClientCredentialsSupported() *
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GrantTypeDeviceCodeSupported mocks base method.
|
||||||
|
func (m *MockConfiguration) GrantTypeDeviceCodeSupported() bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GrantTypeDeviceCodeSupported")
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GrantTypeDeviceCodeSupported indicates an expected call of GrantTypeDeviceCodeSupported.
|
||||||
|
func (mr *MockConfigurationMockRecorder) GrantTypeDeviceCodeSupported() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeDeviceCodeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeDeviceCodeSupported))
|
||||||
|
}
|
||||||
|
|
||||||
// GrantTypeJWTAuthorizationSupported mocks base method.
|
// GrantTypeJWTAuthorizationSupported mocks base method.
|
||||||
func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool {
|
func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -358,6 +400,20 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserCodeFormEndpoint mocks base method.
|
||||||
|
func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "UserCodeFormEndpoint")
|
||||||
|
ret0, _ := ret[0].(op.Endpoint)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint.
|
||||||
|
func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint))
|
||||||
|
}
|
||||||
|
|
||||||
// UserinfoEndpoint mocks base method.
|
// UserinfoEndpoint mocks base method.
|
||||||
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
|
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
49
pkg/op/op.go
49
pkg/op/op.go
|
@ -27,17 +27,19 @@ const (
|
||||||
defaultRevocationEndpoint = "revoke"
|
defaultRevocationEndpoint = "revoke"
|
||||||
defaultEndSessionEndpoint = "end_session"
|
defaultEndSessionEndpoint = "end_session"
|
||||||
defaultKeysEndpoint = "keys"
|
defaultKeysEndpoint = "keys"
|
||||||
|
defaultDeviceAuthzEndpoint = "/device_authorization"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
DefaultEndpoints = &endpoints{
|
DefaultEndpoints = &endpoints{
|
||||||
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
|
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
|
||||||
Token: NewEndpoint(defaultTokenEndpoint),
|
Token: NewEndpoint(defaultTokenEndpoint),
|
||||||
Introspection: NewEndpoint(defaultIntrospectEndpoint),
|
Introspection: NewEndpoint(defaultIntrospectEndpoint),
|
||||||
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
|
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
|
||||||
Revocation: NewEndpoint(defaultRevocationEndpoint),
|
Revocation: NewEndpoint(defaultRevocationEndpoint),
|
||||||
EndSession: NewEndpoint(defaultEndSessionEndpoint),
|
EndSession: NewEndpoint(defaultEndSessionEndpoint),
|
||||||
JwksURI: NewEndpoint(defaultKeysEndpoint),
|
JwksURI: NewEndpoint(defaultKeysEndpoint),
|
||||||
|
DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultCORSOptions = cors.Options{
|
defaultCORSOptions = cors.Options{
|
||||||
|
@ -95,6 +97,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
|
||||||
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
|
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
|
||||||
router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
|
router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
|
||||||
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
|
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
|
||||||
|
router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o))
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,17 +121,19 @@ type Config struct {
|
||||||
GrantTypeRefreshToken bool
|
GrantTypeRefreshToken bool
|
||||||
RequestObjectSupported bool
|
RequestObjectSupported bool
|
||||||
SupportedUILocales []language.Tag
|
SupportedUILocales []language.Tag
|
||||||
|
DeviceAuthorization DeviceAuthorizationConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type endpoints struct {
|
type endpoints struct {
|
||||||
Authorization Endpoint
|
Authorization Endpoint
|
||||||
Token Endpoint
|
Token Endpoint
|
||||||
Introspection Endpoint
|
Introspection Endpoint
|
||||||
Userinfo Endpoint
|
Userinfo Endpoint
|
||||||
Revocation Endpoint
|
Revocation Endpoint
|
||||||
EndSession Endpoint
|
EndSession Endpoint
|
||||||
CheckSessionIframe Endpoint
|
CheckSessionIframe Endpoint
|
||||||
JwksURI Endpoint
|
JwksURI Endpoint
|
||||||
|
DeviceAuthorization Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
|
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
|
||||||
|
@ -145,6 +150,7 @@ type endpoints struct {
|
||||||
// /revoke
|
// /revoke
|
||||||
// /end_session
|
// /end_session
|
||||||
// /keys
|
// /keys
|
||||||
|
// /device_authorization
|
||||||
//
|
//
|
||||||
// This does not include login. Login is handled with a redirect that includes the
|
// This does not include login. Login is handled with a redirect that includes the
|
||||||
// request ID. The redirect for logins is specified per-client by Client.LoginURL().
|
// request ID. The redirect for logins is specified per-client by Client.LoginURL().
|
||||||
|
@ -242,6 +248,10 @@ func (o *Provider) EndSessionEndpoint() Endpoint {
|
||||||
return o.endpoints.EndSession
|
return o.endpoints.EndSession
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
|
||||||
|
return o.endpoints.DeviceAuthorization
|
||||||
|
}
|
||||||
|
|
||||||
func (o *Provider) KeysEndpoint() Endpoint {
|
func (o *Provider) KeysEndpoint() Endpoint {
|
||||||
return o.endpoints.JwksURI
|
return o.endpoints.JwksURI
|
||||||
}
|
}
|
||||||
|
@ -275,6 +285,11 @@ func (o *Provider) GrantTypeJWTAuthorizationSupported() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Provider) GrantTypeDeviceCodeSupported() bool {
|
||||||
|
_, ok := o.storage.(DeviceAuthorizationStorage)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
|
func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -308,6 +323,10 @@ func (o *Provider) SupportedUILocales() []language.Tag {
|
||||||
return o.config.SupportedUILocales
|
return o.config.SupportedUILocales
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig {
|
||||||
|
return o.config.DeviceAuthorization
|
||||||
|
}
|
||||||
|
|
||||||
func (o *Provider) Storage() Storage {
|
func (o *Provider) Storage() Storage {
|
||||||
return o.storage
|
return o.storage
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,3 +151,50 @@ type EndSessionRequest struct {
|
||||||
ClientID string
|
ClientID string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrDuplicateUserCode = errors.New("user code already exists")
|
||||||
|
|
||||||
|
type DeviceAuthorizationState struct {
|
||||||
|
ClientID string
|
||||||
|
Scopes []string
|
||||||
|
Expires time.Time
|
||||||
|
Done bool
|
||||||
|
Subject string
|
||||||
|
Denied bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeviceAuthorizationStorage interface {
|
||||||
|
// StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
|
||||||
|
// User code will be used by the user to complete the login flow and must be unique.
|
||||||
|
// ErrDuplicateUserCode signals the caller should try again with a new code.
|
||||||
|
//
|
||||||
|
// Note that user codes are low entropy keys and when many exist in the
|
||||||
|
// database, the change for collisions increases. Therefore implementers
|
||||||
|
// of this interface must make sure that user codes of expired authentication flows are purged,
|
||||||
|
// after some time.
|
||||||
|
StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error
|
||||||
|
|
||||||
|
// GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database.
|
||||||
|
// The method is polled untill the the authorization is eighter Completed, Expired or Denied.
|
||||||
|
GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error)
|
||||||
|
|
||||||
|
// GetDeviceAuthorizationByUserCode resturn the current state of the device authorization flow,
|
||||||
|
// identified by the user code.
|
||||||
|
GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*DeviceAuthorizationState, error)
|
||||||
|
|
||||||
|
// CompleteDeviceAuthorization marks a device authorization entry as Completed,
|
||||||
|
// identified by userCode. The Subject is added to the state, so that
|
||||||
|
// GetDeviceAuthorizatonState can use it to create a new Access Token.
|
||||||
|
CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
|
||||||
|
|
||||||
|
// DenyDeviceAuthorization marks a device authorization entry as Denied.
|
||||||
|
DenyDeviceAuthorization(ctx context.Context, userCode string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) {
|
||||||
|
storage, ok := s.(DeviceAuthorizationStorage)
|
||||||
|
if !ok {
|
||||||
|
return nil, oidc.ErrUnsupportedGrantType().WithDescription("device_code grant not supported")
|
||||||
|
}
|
||||||
|
return storage, nil
|
||||||
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
|
|
||||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
@ -50,38 +49,19 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {
|
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {
|
||||||
err = r.ParseForm()
|
clientID, authenticated, err := ClientIDFromRequest(r, introspector)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", errors.New("unable to parse request")
|
return "", "", err
|
||||||
}
|
}
|
||||||
req := new(struct {
|
if !authenticated {
|
||||||
oidc.IntrospectionRequest
|
return "", "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
|
||||||
oidc.ClientAssertionParams
|
}
|
||||||
})
|
|
||||||
|
req := new(oidc.IntrospectionRequest)
|
||||||
err = introspector.Decoder().Decode(req, r.Form)
|
err = introspector.Decoder().Decode(req, r.Form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", errors.New("unable to parse request")
|
return "", "", errors.New("unable to parse request")
|
||||||
}
|
}
|
||||||
if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" {
|
|
||||||
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier(r.Context()))
|
return req.Token, clientID, nil
|
||||||
if err == nil {
|
|
||||||
return req.Token, profile.Issuer, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
clientID, clientSecret, ok := r.BasicAuth()
|
|
||||||
if ok {
|
|
||||||
clientID, err = url.QueryUnescape(clientID)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", errors.New("invalid basic auth header")
|
|
||||||
}
|
|
||||||
clientSecret, err = url.QueryUnescape(clientSecret)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", errors.New("invalid basic auth header")
|
|
||||||
}
|
|
||||||
if err := introspector.Storage().AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
return req.Token, clientID, nil
|
|
||||||
}
|
|
||||||
return "", "", errors.New("invalid authorization")
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ type Exchanger interface {
|
||||||
GrantTypeTokenExchangeSupported() bool
|
GrantTypeTokenExchangeSupported() bool
|
||||||
GrantTypeJWTAuthorizationSupported() bool
|
GrantTypeJWTAuthorizationSupported() bool
|
||||||
GrantTypeClientCredentialsSupported() bool
|
GrantTypeClientCredentialsSupported() bool
|
||||||
|
GrantTypeDeviceCodeSupported() bool
|
||||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||||
}
|
}
|
||||||
|
@ -56,6 +57,11 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
ClientCredentialsExchange(w, r, exchanger)
|
ClientCredentialsExchange(w, r, exchanger)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
case string(oidc.GrantTypeDeviceCode):
|
||||||
|
if exchanger.GrantTypeDeviceCodeSupported() {
|
||||||
|
DeviceAccessToken(w, r, exchanger)
|
||||||
|
return
|
||||||
|
}
|
||||||
case "":
|
case "":
|
||||||
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
|
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
|
||||||
return
|
return
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue