diff --git a/example/server/login.go b/example/server/exampleop/login.go similarity index 91% rename from example/server/login.go rename to example/server/exampleop/login.go index 90d01d8..fd3dead 100644 --- a/example/server/login.go +++ b/example/server/exampleop/login.go @@ -1,4 +1,4 @@ -package main +package exampleop import ( "fmt" @@ -12,8 +12,7 @@ const ( queryAuthRequestID = "authRequestID" ) -var ( - loginTmpl, _ = template.New("login").Parse(` +var loginTmpl, _ = template.New("login").Parse(` @@ -41,7 +40,6 @@ var ( `) -) type login struct { authenticate authenticate @@ -74,8 +72,8 @@ func (l *login) loginHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("cannot parse form:%s", err), http.StatusInternalServerError) return } - //the oidc package will pass the id of the auth request as query parameter - //we will use this id through the login process and therefore pass it to the login page + // the oidc package will pass the id of the auth request as query parameter + // we will use this id through the login process and therefore pass it to the login page renderLogin(w, r.FormValue(queryAuthRequestID), nil) } diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go new file mode 100644 index 0000000..4794d8a --- /dev/null +++ b/example/server/exampleop/op.go @@ -0,0 +1,116 @@ +package exampleop + +import ( + "context" + "crypto/sha256" + "log" + "net/http" + "os" + + "github.com/gorilla/mux" + "golang.org/x/text/language" + + "github.com/zitadel/oidc/example/server/storage" + "github.com/zitadel/oidc/pkg/op" +) + +const ( + pathLoggedOut = "/logged-out" +) + +func init() { + storage.RegisterClients( + storage.NativeClient("native"), + storage.WebClient("web", "secret"), + storage.WebClient("api", "secret"), + ) +} + +type Storage interface { + op.Storage + CheckUsernamePassword(username, password, id string) error +} + +// SetupServer creates an OIDC server with Issuer=http://localhost: +// +// Use one of the pre-made clients in storage/clients.go or register a new one. +func SetupServer(ctx context.Context, issuer string, storage Storage) *mux.Router { + // this will allow us to use an issuer with http:// instead of https:// + os.Setenv(op.OidcDevMode, "true") + + // the OpenID Provider requires a 32-byte key for (token) encryption + // be sure to create a proper crypto random key and manage it securely! + key := sha256.Sum256([]byte("test")) + + router := mux.NewRouter() + + // for simplicity, we provide a very small default page for users who have signed out + router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { + _, err := w.Write([]byte("signed out successfully")) + if err != nil { + log.Printf("error serving logged out page: %v", err) + } + }) + + // creation of the OpenIDProvider with the just created in-memory Storage + provider, err := newOP(ctx, storage, issuer, key) + if err != nil { + log.Fatal(err) + } + + // the provider will only take care of the OpenID Protocol, so there must be some sort of UI for the login process + // for the simplicity of the example this means a simple page with username and password field + l := NewLogin(storage, op.AuthCallbackURL(provider)) + + // regardless of how many pages / steps there are in the process, the UI must be registered in the router, + // so we will direct all calls to /login to the login UI + router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) + + // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) + // is served on the correct path + // + // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), + // then you would have to set the path prefix (/custom/path/) + router.PathPrefix("/").Handler(provider.HttpHandler()) + + return router +} + +// newOP will create an OpenID Provider for localhost on a specified port with a given encryption key +// and a predefined default logout uri +// it will enable all options (see descriptions) +func newOP(ctx context.Context, storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, error) { + config := &op.Config{ + Issuer: issuer, + CryptoKey: key, + + // will be used if the end_session endpoint is called without a post_logout_redirect_uri + DefaultLogoutRedirectURI: pathLoggedOut, + + // enables code_challenge_method S256 for PKCE (and therefore PKCE in general) + CodeMethodS256: true, + + // enables additional client_id/client_secret authentication by form post (not only HTTP Basic Auth) + AuthMethodPost: true, + + // enables additional authentication by using private_key_jwt + AuthMethodPrivateKeyJWT: true, + + // enables refresh_token grant use + GrantTypeRefreshToken: true, + + // enables use of the `request` Object parameter + RequestObjectSupported: true, + + // this example has only static texts (in English), so we'll set the here accordingly + SupportedUILocales: []language.Tag{language.English}, + } + handler, err := op.NewOpenIDProvider(ctx, config, storage, + // as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth + op.WithCustomAuthEndpoint(op.NewEndpoint("auth")), + ) + if err != nil { + return nil, err + } + return handler, nil +} diff --git a/example/server/internal/storage.go b/example/server/internal/storage.go deleted file mode 100644 index 5fd61c5..0000000 --- a/example/server/internal/storage.go +++ /dev/null @@ -1,553 +0,0 @@ -package internal - -import ( - "context" - "crypto/rand" - "crypto/rsa" - "fmt" - "math/big" - "time" - - "github.com/google/uuid" - "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" - - "github.com/zitadel/oidc/pkg/oidc" - "github.com/zitadel/oidc/pkg/op" -) - -var ( - //serviceKey1 is a public key which will be used for the JWT Profile Authorization Grant - //the corresponding private key is in the service-key1.json (for demonstration purposes) - serviceKey1 = &rsa.PublicKey{ - N: func() *big.Int { - n, _ := new(big.Int).SetString("00f6d44fb5f34ac2033a75e73cb65ff24e6181edc58845e75a560ac21378284977bb055b1a75b714874e2a2641806205681c09abec76efd52cf40984edcf4c8ca09717355d11ac338f280d3e4c905b00543bdb8ee5a417496cb50cb0e29afc5a0d0471fd5a2fa625bd5281f61e6b02067d4fe7a5349eeae6d6a4300bcd86eef331", 16) - return n - }(), - E: 65537, - } -) - -//storage implements the op.Storage interface -//typically you would implement this as a layer on top of your database -//for simplicity this example keeps everything in-memory -type storage struct { - authRequests map[string]*AuthRequest - codes map[string]string - tokens map[string]*Token - clients map[string]*Client - users map[string]*User - services map[string]Service - refreshTokens map[string]*RefreshToken - signingKey signingKey -} - -type signingKey struct { - ID string - Algorithm string - Key *rsa.PrivateKey -} - -func NewStorage() *storage { - key, _ := rsa.GenerateKey(rand.Reader, 2048) - return &storage{ - authRequests: make(map[string]*AuthRequest), - codes: make(map[string]string), - tokens: make(map[string]*Token), - refreshTokens: make(map[string]*RefreshToken), - clients: clients, - users: map[string]*User{ - "id1": { - id: "id1", - username: "test-user", - password: "verysecure", - firstname: "Test", - lastname: "User", - email: "test-user@zitadel.ch", - emailVerified: true, - phone: "", - phoneVerified: false, - preferredLanguage: language.German, - }, - }, - services: map[string]Service{ - "service": { - keys: map[string]*rsa.PublicKey{ - "key1": serviceKey1, - }, - }, - }, - signingKey: signingKey{ - ID: "id", - Algorithm: "RS256", - Key: key, - }, - } -} - -//CheckUsernamePassword implements the `authenticate` interface of the login -func (s *storage) CheckUsernamePassword(username, password, id string) error { - request, ok := s.authRequests[id] - if !ok { - return fmt.Errorf("request not found") - } - - //for demonstration purposes we'll check on a static list with plain text password - //for real world scenarios, be sure to have the password hashed and salted (e.g. using bcrypt) - for _, user := range s.users { - if user.username == username && user.password == password { - //be sure to set user id into the auth request after the user was checked, - //so that you'll be able to get more information about the user after the login - request.UserID = user.id - - //you will have to change some state on the request to guide the user through possible multiple steps of the login process - //in this example we'll simply check the username / password and set a boolean to true - //therefore we will also just check this boolean if the request / login has been finished - request.passwordChecked = true - return nil - } - } - return fmt.Errorf("username or password wrong") -} - -//CreateAuthRequest implements the op.Storage interface -//it will be called after parsing and validation of the authentication request -func (s *storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) { - //typically, you'll fill your internal / storage model with the information of the passed object - request := authRequestToInternal(authReq, userID) - - //you'll also have to create a unique id for the request (this might be done by your database; we'll use a uuid) - request.ID = uuid.NewString() - - //and save it in your database (for demonstration purposed we will use a simple map) - s.authRequests[request.ID] = request - - //finally, return the request (which implements the AuthRequest interface of the OP - return request, nil -} - -//AuthRequestByID implements the op.Storage interface -//it will be called after the Login UI redirects back to the OIDC endpoint -func (s *storage) AuthRequestByID(ctx context.Context, id string) (op.AuthRequest, error) { - request, ok := s.authRequests[id] - if !ok { - return nil, fmt.Errorf("request not found") - } - return request, nil -} - -//AuthRequestByCode implements the op.Storage interface -//it will be called after parsing and validation of the token request (in an authorization code flow) -func (s *storage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) { - //for this example we read the id by code and then get the request by id - requestID, ok := s.codes[code] - if !ok { - return nil, fmt.Errorf("code invalid or expired") - } - return s.AuthRequestByID(ctx, requestID) -} - -//SaveAuthCode implements the op.Storage interface -//it will be called after the authentication has been successful and before redirecting the user agent to the redirect_uri -//(in an authorization code flow) -func (s *storage) SaveAuthCode(ctx context.Context, id string, code string) error { - //for this example we'll just save the authRequestID to the code - s.codes[code] = id - return nil -} - -//DeleteAuthRequest implements the op.Storage interface -//it will be called after creating the token response (id and access tokens) for a valid -//- authentication request (in an implicit flow) -//- token request (in an authorization code flow) -func (s *storage) DeleteAuthRequest(ctx context.Context, id string) error { - //you can simply delete all reference to the auth request - delete(s.authRequests, id) - for code, requestID := range s.codes { - if id == requestID { - delete(s.codes, code) - return nil - } - } - return nil -} - -//CreateAccessToken implements the op.Storage interface -//it will be called for all requests able to return an access token (Authorization Code Flow, Implicit Flow, JWT Profile, ...) -func (s *storage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (string, time.Time, error) { - var applicationID string - //if authenticated for an app (auth code / implicit flow) we must save the client_id to the token - authReq, ok := request.(*AuthRequest) - if ok { - applicationID = authReq.ApplicationID - } - token, err := s.accessToken(applicationID, "", request.GetSubject(), request.GetAudience(), request.GetScopes()) - if err != nil { - return "", time.Time{}, err - } - return token.ID, token.Expiration, nil -} - -//CreateAccessAndRefreshTokens implements the op.Storage interface -//it will be called for all requests able to return an access and refresh token (Authorization Code Flow, Refresh Token Request) -func (s *storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) { - //get the information depending on the request type / implementation - applicationID, authTime, amr := getInfoFromRequest(request) - - //if currentRefreshToken is empty (Code Flow) we will have to create a new refresh token - if currentRefreshToken == "" { - refreshTokenID := uuid.NewString() - accessToken, err := s.accessToken(applicationID, refreshTokenID, request.GetSubject(), request.GetAudience(), request.GetScopes()) - if err != nil { - return "", "", time.Time{}, err - } - refreshToken, err := s.createRefreshToken(accessToken, amr, authTime) - if err != nil { - return "", "", time.Time{}, err - } - return accessToken.ID, refreshToken, accessToken.Expiration, nil - } - - //if we get here, the currentRefreshToken was not empty, so the call is a refresh token request - //we therefore will have to check the currentRefreshToken and renew the refresh token - refreshToken, refreshTokenID, err := s.renewRefreshToken(currentRefreshToken) - if err != nil { - return "", "", time.Time{}, err - } - accessToken, err := s.accessToken(applicationID, refreshTokenID, request.GetSubject(), request.GetAudience(), request.GetScopes()) - if err != nil { - return "", "", time.Time{}, err - } - return accessToken.ID, refreshToken, accessToken.Expiration, nil -} - -//TokenRequestByRefreshToken implements the op.Storage interface -//it will be called after parsing and validation of the refresh token request -func (s *storage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) { - token, ok := s.refreshTokens[refreshToken] - if !ok { - return nil, fmt.Errorf("invalid refresh_token") - } - return RefreshTokenRequestFromBusiness(token), nil -} - -//TerminateSession implements the op.Storage interface -//it will be called after the user signed out, therefore the access and refresh token of the user of this client must be removed -func (s *storage) TerminateSession(ctx context.Context, userID string, clientID string) error { - for _, token := range s.tokens { - if token.ApplicationID == clientID && token.Subject == userID { - delete(s.tokens, token.ID) - delete(s.refreshTokens, token.RefreshTokenID) - return nil - } - } - return nil -} - -//RevokeToken implements the op.Storage interface -//it will be called after parsing and validation of the token revocation request -func (s *storage) RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error { - //a single token was requested to be removed - accessToken, ok := s.tokens[token] - if ok { - if accessToken.ApplicationID != clientID { - return oidc.ErrInvalidClient().WithDescription("token was not issued for this client") - } - //if it is an access token, just remove it - //you could also remove the corresponding refresh token if really necessary - delete(s.tokens, accessToken.ID) - return nil - } - refreshToken, ok := s.refreshTokens[token] - if !ok { - //if the token is neither an access nor a refresh token, just ignore it, the expected behaviour of - //being not valid (anymore) is achieved - return nil - } - if refreshToken.ApplicationID != clientID { - return oidc.ErrInvalidClient().WithDescription("token was not issued for this client") - } - //if it is a refresh token, you will have to remove the access token as well - delete(s.refreshTokens, refreshToken.ID) - for _, accessToken := range s.tokens { - if accessToken.RefreshTokenID == refreshToken.ID { - delete(s.tokens, accessToken.ID) - return nil - } - } - return nil -} - -//GetSigningKey implements the op.Storage interface -//it will be called when creating the OpenID Provider -func (s *storage) GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey) { - //in this example the signing key is a static rsa.PrivateKey and the algorithm used is RS256 - //you would obviously have a more complex implementation and store / retrieve the key from your database as well - // - //the idea of the signing key channel is, that you can (with what ever mechanism) rotate your signing key and - //switch the key of the signer via this channel - keyCh <- jose.SigningKey{ - Algorithm: jose.SignatureAlgorithm(s.signingKey.Algorithm), //always tell the signer with algorithm to use - Key: jose.JSONWebKey{ - KeyID: s.signingKey.ID, //always give the key an id so, that it will include it in the token header as `kid` claim - Key: s.signingKey.Key, - }, - } -} - -//GetKeySet implements the op.Storage interface -//it will be called to get the current (public) keys, among others for the keys_endpoint or for validating access_tokens on the userinfo_endpoint, ... -func (s *storage) GetKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) { - //as mentioned above, this example only has a single signing key without key rotation, - //so it will directly use its public key - // - //when using key rotation you typically would store the public keys alongside the private keys in your database - //and give both of them an expiration date, with the public key having a longer lifetime (e.g. rotate private key every - return &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ - { - KeyID: s.signingKey.ID, - Algorithm: s.signingKey.Algorithm, - Use: oidc.KeyUseSignature, - Key: &s.signingKey.Key.PublicKey, - }}, - }, nil -} - -//GetClientByClientID implements the op.Storage interface -//it will be called whenever information (type, redirect_uris, ...) about the client behind the client_id is needed -func (s *storage) GetClientByClientID(ctx context.Context, clientID string) (op.Client, error) { - client, ok := s.clients[clientID] - if !ok { - return nil, fmt.Errorf("client not found") - } - return client, nil -} - -//AuthorizeClientIDSecret implements the op.Storage interface -//it will be called for validating the client_id, client_secret on token or introspection requests -func (s *storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error { - client, ok := s.clients[clientID] - if !ok { - return fmt.Errorf("client not found") - } - //for this example we directly check the secret - //obviously you would not have the secret in plain text, but rather hashed and salted (e.g. using bcrypt) - if client.secret != clientSecret { - return fmt.Errorf("invalid secret") - } - return nil -} - -//SetUserinfoFromScopes implements the op.Storage interface -//it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check -func (s *storage) SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, userID, clientID string, scopes []string) error { - return s.setUserinfo(ctx, userinfo, userID, clientID, scopes) -} - -//SetUserinfoFromToken implements the op.Storage interface -//it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function -func (s *storage) SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, tokenID, subject, origin string) error { - token, ok := s.tokens[tokenID] - if !ok { - return fmt.Errorf("token is invalid or has expired") - } - //the userinfo endpoint should support CORS. If it's not possible to specify a specific origin in the CORS handler, - //and you have to specify a wildcard (*) origin, then you could also check here if the origin which called the userinfo endpoint here directly - //note that the origin can be empty (if called by a web client) - // - //if origin != "" { - // client, ok := s.clients[token.ApplicationID] - // if !ok { - // return fmt.Errorf("client not found") - // } - // if err := checkAllowedOrigins(client.allowedOrigins, origin); err != nil { - // return err - // } - //} - return s.setUserinfo(ctx, userinfo, token.Subject, token.ApplicationID, token.Scopes) -} - -//SetIntrospectionFromToken implements the op.Storage interface -//it will be called for the introspection endpoint, so we read the token and pass the information from that to the private function -func (s *storage) SetIntrospectionFromToken(ctx context.Context, introspection oidc.IntrospectionResponse, tokenID, subject, clientID string) error { - token, ok := s.tokens[tokenID] - if !ok { - return fmt.Errorf("token is invalid or has expired") - } - //check if the client is part of the requested audience - for _, aud := range token.Audience { - if aud == clientID { - //the introspection response only has to return a boolean (active) if the token is active - //this will automatically be done by the library if you don't return an error - //you can also return further information about the user / associated token - //e.g. the userinfo (equivalent to userinfo endpoint) - err := s.setUserinfo(ctx, introspection, subject, clientID, token.Scopes) - if err != nil { - return err - } - //...and also the requested scopes... - introspection.SetScopes(token.Scopes) - //...and the client the token was issued to - introspection.SetClientID(token.ApplicationID) - return nil - } - } - return fmt.Errorf("token is not valid for this client") -} - -//GetPrivateClaimsFromScopes implements the op.Storage interface -//it will be called for the creation of a JWT access token to assert claims for custom scopes -func (s *storage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) { - for _, scope := range scopes { - switch scope { - case CustomScope: - claims = appendClaim(claims, CustomClaim, customClaim(clientID)) - } - } - return claims, nil -} - -//GetKeyByIDAndUserID implements the op.Storage interface -//it will be called to validate the signatures of a JWT (JWT Profile Grant and Authentication) -func (s *storage) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) { - service, ok := s.services[userID] - if !ok { - return nil, fmt.Errorf("user not found") - } - key, ok := service.keys[keyID] - if !ok { - return nil, fmt.Errorf("key not found") - } - return &jose.JSONWebKey{ - KeyID: keyID, - Use: "sig", - Key: key, - }, nil -} - -//ValidateJWTProfileScopes implements the op.Storage interface -//it will be called to validate the scopes of a JWT Profile Authorization Grant request -func (s *storage) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) { - allowedScopes := make([]string, 0) - for _, scope := range scopes { - if scope == oidc.ScopeOpenID { - allowedScopes = append(allowedScopes, scope) - } - } - return allowedScopes, nil -} - -//Health implements the op.Storage interface -func (s *storage) Health(ctx context.Context) error { - return nil -} - -//createRefreshToken will store a refresh_token in-memory based on the provided information -func (s *storage) createRefreshToken(accessToken *Token, amr []string, authTime time.Time) (string, error) { - token := &RefreshToken{ - ID: accessToken.RefreshTokenID, - Token: accessToken.RefreshTokenID, - AuthTime: authTime, - AMR: amr, - ApplicationID: accessToken.ApplicationID, - UserID: accessToken.Subject, - Audience: accessToken.Audience, - Expiration: time.Now().Add(5 * time.Hour), - Scopes: accessToken.Scopes, - } - s.refreshTokens[token.ID] = token - return token.Token, nil -} - -//renewRefreshToken checks the provided refresh_token and creates a new one based on the current -func (s *storage) renewRefreshToken(currentRefreshToken string) (string, string, error) { - refreshToken, ok := s.refreshTokens[currentRefreshToken] - if !ok { - return "", "", fmt.Errorf("invalid refresh token") - } - //deletes the refresh token and all access tokens which were issued based on this refresh token - delete(s.refreshTokens, currentRefreshToken) - for _, token := range s.tokens { - if token.RefreshTokenID == currentRefreshToken { - delete(s.tokens, token.ID) - break - } - } - //creates a new refresh token based on the current one - token := uuid.NewString() - refreshToken.Token = token - s.refreshTokens[token] = refreshToken - return token, refreshToken.ID, nil -} - -//accessToken will store an access_token in-memory based on the provided information -func (s *storage) accessToken(applicationID, refreshTokenID, subject string, audience, scopes []string) (*Token, error) { - token := &Token{ - ID: uuid.NewString(), - ApplicationID: applicationID, - RefreshTokenID: refreshTokenID, - Subject: subject, - Audience: audience, - Expiration: time.Now().Add(5 * time.Minute), - Scopes: scopes, - } - s.tokens[token.ID] = token - return token, nil -} - -//setUserinfo sets the info based on the user, scopes and if necessary the clientID -func (s *storage) setUserinfo(ctx context.Context, userInfo oidc.UserInfoSetter, userID, clientID string, scopes []string) (err error) { - user, ok := s.users[userID] - if !ok { - return fmt.Errorf("user not found") - } - for _, scope := range scopes { - switch scope { - case oidc.ScopeOpenID: - userInfo.SetSubject(user.id) - case oidc.ScopeEmail: - userInfo.SetEmail(user.email, user.emailVerified) - case oidc.ScopeProfile: - userInfo.SetPreferredUsername(user.username) - userInfo.SetName(user.firstname + " " + user.lastname) - userInfo.SetFamilyName(user.lastname) - userInfo.SetGivenName(user.firstname) - userInfo.SetLocale(user.preferredLanguage) - case oidc.ScopePhone: - userInfo.SetPhone(user.phone, user.phoneVerified) - case CustomScope: - //you can also have a custom scope and assert public or custom claims based on that - userInfo.AppendClaims(CustomClaim, customClaim(clientID)) - } - } - return nil -} - -//getInfoFromRequest returns the clientID, authTime and amr depending on the op.TokenRequest type / implementation -func getInfoFromRequest(req op.TokenRequest) (clientID string, authTime time.Time, amr []string) { - authReq, ok := req.(*AuthRequest) //Code Flow (with scope offline_access) - if ok { - return authReq.ApplicationID, authReq.authTime, authReq.GetAMR() - } - refreshReq, ok := req.(*RefreshTokenRequest) //Refresh Token Request - if ok { - return refreshReq.ApplicationID, refreshReq.AuthTime, refreshReq.AMR - } - return "", time.Time{}, nil -} - -//customClaim demonstrates how to return custom claims based on provided information -func customClaim(clientID string) map[string]interface{} { - return map[string]interface{}{ - "client": clientID, - "other": "stuff", - } -} - -func appendClaim(claims map[string]interface{}, claim string, value interface{}) map[string]interface{} { - if claims == nil { - claims = make(map[string]interface{}) - } - claims[claim] = value - return claims -} diff --git a/example/server/internal/user.go b/example/server/internal/user.go deleted file mode 100644 index 19b5d1f..0000000 --- a/example/server/internal/user.go +++ /dev/null @@ -1,24 +0,0 @@ -package internal - -import ( - "crypto/rsa" - - "golang.org/x/text/language" -) - -type User struct { - id string - username string - password string - firstname string - lastname string - email string - emailVerified bool - phone string - phoneVerified bool - preferredLanguage language.Tag -} - -type Service struct { - keys map[string]*rsa.PublicKey -} diff --git a/example/server/main.go b/example/server/main.go new file mode 100644 index 0000000..37fbcb3 --- /dev/null +++ b/example/server/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "log" + "net/http" + + "github.com/zitadel/oidc/example/server/exampleop" + "github.com/zitadel/oidc/example/server/storage" +) + +func main() { + ctx := context.Background() + + // the OpenIDProvider interface needs a Storage interface handling various checks and state manipulations + // this might be the layer for accessing your database + // in this example it will be handled in-memory + storage := storage.NewStorage(storage.NewUserStore()) + + port := "9998" + router := exampleop.SetupServer(ctx, "http://localhost:"+port, storage) + + server := &http.Server{ + Addr: ":" + port, + Handler: router, + } + err := server.ListenAndServe() + if err != nil { + log.Fatal(err) + } + <-ctx.Done() +} diff --git a/example/server/op.go b/example/server/op.go deleted file mode 100644 index d689247..0000000 --- a/example/server/op.go +++ /dev/null @@ -1,126 +0,0 @@ -package main - -import ( - "context" - "crypto/sha256" - "fmt" - "log" - "net/http" - "os" - - "github.com/gorilla/mux" - "golang.org/x/text/language" - - "github.com/zitadel/oidc/example/server/internal" - "github.com/zitadel/oidc/pkg/op" -) - -const ( - pathLoggedOut = "/logged-out" -) - -func init() { - internal.RegisterClients( - internal.NativeClient("native"), - internal.WebClient("web", "secret"), - internal.WebClient("api", "secret"), - ) -} - -func main() { - ctx := context.Background() - - //this will allow us to use an issuer with http:// instead of https:// - os.Setenv(op.OidcDevMode, "true") - - port := "9998" - - //the OpenID Provider requires a 32-byte key for (token) encryption - //be sure to create a proper crypto random key and manage it securely! - key := sha256.Sum256([]byte("test")) - - router := mux.NewRouter() - - //for simplicity, we provide a very small default page for users who have signed out - router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { - _, err := w.Write([]byte("signed out successfully")) - if err != nil { - log.Printf("error serving logged out page: %v", err) - } - }) - - //the OpenIDProvider interface needs a Storage interface handling various checks and state manipulations - //this might be the layer for accessing your database - //in this example it will be handled in-memory - storage := internal.NewStorage() - - //creation of the OpenIDProvider with the just created in-memory Storage - provider, err := newOP(ctx, storage, port, key) - if err != nil { - log.Fatal(err) - } - - //the provider will only take care of the OpenID Protocol, so there must be some sort of UI for the login process - //for the simplicity of the example this means a simple page with username and password field - l := NewLogin(storage, op.AuthCallbackURL(provider)) - - //regardless of how many pages / steps there are in the process, the UI must be registered in the router, - //so we will direct all calls to /login to the login UI - router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) - - //we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) - //is served on the correct path - // - //if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), - //then you would have to set the path prefix (/custom/path/) - router.PathPrefix("/").Handler(provider.HttpHandler()) - - server := &http.Server{ - Addr: ":" + port, - Handler: router, - } - err = server.ListenAndServe() - if err != nil { - log.Fatal(err) - } - <-ctx.Done() -} - -//newOP will create an OpenID Provider for localhost on a specified port with a given encryption key -//and a predefined default logout uri -//it will enable all options (see descriptions) -func newOP(ctx context.Context, storage op.Storage, port string, key [32]byte) (op.OpenIDProvider, error) { - config := &op.Config{ - Issuer: fmt.Sprintf("http://localhost:%s/", port), - CryptoKey: key, - - //will be used if the end_session endpoint is called without a post_logout_redirect_uri - DefaultLogoutRedirectURI: pathLoggedOut, - - //enables code_challenge_method S256 for PKCE (and therefore PKCE in general) - CodeMethodS256: true, - - //enables additional client_id/client_secret authentication by form post (not only HTTP Basic Auth) - AuthMethodPost: true, - - //enables additional authentication by using private_key_jwt - AuthMethodPrivateKeyJWT: true, - - //enables refresh_token grant use - GrantTypeRefreshToken: true, - - //enables use of the `request` Object parameter - RequestObjectSupported: true, - - //this example has only static texts (in English), so we'll set the here accordingly - SupportedUILocales: []language.Tag{language.English}, - } - handler, err := op.NewOpenIDProvider(ctx, config, storage, - //as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth - op.WithCustomAuthEndpoint(op.NewEndpoint("auth")), - ) - if err != nil { - return nil, err - } - return handler, nil -} diff --git a/example/server/internal/client.go b/example/server/storage/client.go similarity index 60% rename from example/server/internal/client.go rename to example/server/storage/client.go index 9080e8c..0f3a703 100644 --- a/example/server/internal/client.go +++ b/example/server/storage/client.go @@ -1,4 +1,4 @@ -package internal +package storage import ( "time" @@ -8,17 +8,17 @@ import ( ) var ( - //we use the default login UI and pass the (auth request) id + // we use the default login UI and pass the (auth request) id defaultLoginURL = func(id string) string { return "/login/username?authRequestID=" + id } - //clients to be used by the storage interface + // clients to be used by the storage interface clients = map[string]*Client{} ) -//Client represents the internal model of an OAuth/OIDC client -//this could also be your database model +// Client represents the storage model of an OAuth/OIDC client +// this could also be your database model type Client struct { id string secret string @@ -34,108 +34,111 @@ type Client struct { clockSkew time.Duration } -//GetID must return the client_id +// GetID must return the client_id func (c *Client) GetID() string { return c.id } -//RedirectURIs must return the registered redirect_uris for Code and Implicit Flow +// RedirectURIs must return the registered redirect_uris for Code and Implicit Flow func (c *Client) RedirectURIs() []string { return c.redirectURIs } -//PostLogoutRedirectURIs must return the registered post_logout_redirect_uris for sign-outs +// PostLogoutRedirectURIs must return the registered post_logout_redirect_uris for sign-outs func (c *Client) PostLogoutRedirectURIs() []string { return []string{} } -//ApplicationType must return the type of the client (app, native, user agent) +// ApplicationType must return the type of the client (app, native, user agent) func (c *Client) ApplicationType() op.ApplicationType { return c.applicationType } -//AuthMethod must return the authentication method (client_secret_basic, client_secret_post, none, private_key_jwt) +// AuthMethod must return the authentication method (client_secret_basic, client_secret_post, none, private_key_jwt) func (c *Client) AuthMethod() oidc.AuthMethod { return c.authMethod } -//ResponseTypes must return all allowed response types (code, id_token token, id_token) -//these must match with the allowed grant types +// ResponseTypes must return all allowed response types (code, id_token token, id_token) +// these must match with the allowed grant types func (c *Client) ResponseTypes() []oidc.ResponseType { return c.responseTypes } -//GrantTypes must return all allowed grant types (authorization_code, refresh_token, urn:ietf:params:oauth:grant-type:jwt-bearer) +// GrantTypes must return all allowed grant types (authorization_code, refresh_token, urn:ietf:params:oauth:grant-type:jwt-bearer) func (c *Client) GrantTypes() []oidc.GrantType { return c.grantTypes } -//LoginURL will be called to redirect the user (agent) to the login UI -//you could implement some logic here to redirect the users to different login UIs depending on the client +// LoginURL will be called to redirect the user (agent) to the login UI +// you could implement some logic here to redirect the users to different login UIs depending on the client func (c *Client) LoginURL(id string) string { return c.loginURL(id) } -//AccessTokenType must return the type of access token the client uses (Bearer (opaque) or JWT) +// AccessTokenType must return the type of access token the client uses (Bearer (opaque) or JWT) func (c *Client) AccessTokenType() op.AccessTokenType { return c.accessTokenType } -//IDTokenLifetime must return the lifetime of the client's id_tokens +// IDTokenLifetime must return the lifetime of the client's id_tokens func (c *Client) IDTokenLifetime() time.Duration { return 1 * time.Hour } -//DevMode enables the use of non-compliant configs such as redirect_uris (e.g. http schema for user agent client) +// DevMode enables the use of non-compliant configs such as redirect_uris (e.g. http schema for user agent client) func (c *Client) DevMode() bool { return c.devMode } -//RestrictAdditionalIdTokenScopes allows specifying which custom scopes shall be asserted into the id_token +// RestrictAdditionalIdTokenScopes allows specifying which custom scopes shall be asserted into the id_token func (c *Client) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { return func(scopes []string) []string { return scopes } } -//RestrictAdditionalAccessTokenScopes allows specifying which custom scopes shall be asserted into the JWT access_token +// RestrictAdditionalAccessTokenScopes allows specifying which custom scopes shall be asserted into the JWT access_token func (c *Client) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { return func(scopes []string) []string { return scopes } } -//IsScopeAllowed enables Client specific custom scopes validation -//in this example we allow the CustomScope for all clients +// IsScopeAllowed enables Client specific custom scopes validation +// in this example we allow the CustomScope for all clients func (c *Client) IsScopeAllowed(scope string) bool { return scope == CustomScope } -//IDTokenUserinfoClaimsAssertion allows specifying if claims of scope profile, email, phone and address are asserted into the id_token -//even if an access token if issued which violates the OIDC Core spec +// IDTokenUserinfoClaimsAssertion allows specifying if claims of scope profile, email, phone and address are asserted into the id_token +// even if an access token if issued which violates the OIDC Core spec //(5.4. Requesting Claims using Scope Values: https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims) -//some clients though require that e.g. email is always in the id_token when requested even if an access_token is issued +// some clients though require that e.g. email is always in the id_token when requested even if an access_token is issued func (c *Client) IDTokenUserinfoClaimsAssertion() bool { return c.idTokenUserinfoClaimsAssertion } -//ClockSkew enables clients to instruct the OP to apply a clock skew on the various times and expirations +// ClockSkew enables clients to instruct the OP to apply a clock skew on the various times and expirations //(subtract from issued_at, add to expiration, ...) func (c *Client) ClockSkew() time.Duration { return c.clockSkew } -//RegisterClients enables you to register clients for the example implementation -//there are some clients (web and native) to try out different cases -//add more if necessary +// RegisterClients enables you to register clients for the example implementation +// there are some clients (web and native) to try out different cases +// add more if necessary +// +// RegisterClients should be called before the Storage is used so that there are +// no race conditions. func RegisterClients(registerClients ...*Client) { for _, client := range registerClients { clients[client.id] = client } } -//NativeClient will create a client of type native, which will always use PKCE and allow the use of refresh tokens -//user-defined redirectURIs may include: +// NativeClient will create a client of type native, which will always use PKCE and allow the use of refresh tokens +// user-defined redirectURIs may include: // - http://localhost without port specification (e.g. http://localhost/auth/callback) // - custom protocol (e.g. custom://auth/callback) //(the examples will be used as default, if none is provided) @@ -148,7 +151,7 @@ func NativeClient(id string, redirectURIs ...string) *Client { } return &Client{ id: id, - secret: "", //no secret needed (due to PKCE) + secret: "", // no secret needed (due to PKCE) redirectURIs: redirectURIs, applicationType: op.ApplicationTypeNative, authMethod: oidc.AuthMethodNone, @@ -162,8 +165,8 @@ func NativeClient(id string, redirectURIs ...string) *Client { } } -//WebClient will create a client of type web, which will always use Basic Auth and allow the use of refresh tokens -//user-defined redirectURIs may include: +// WebClient will create a client of type web, which will always use Basic Auth and allow the use of refresh tokens +// user-defined redirectURIs may include: // - http://localhost with port specification (e.g. http://localhost:9999/auth/callback) //(the example will be used as default, if none is provided) func WebClient(id, secret string, redirectURIs ...string) *Client { diff --git a/example/server/internal/oidc.go b/example/server/storage/oidc.go similarity index 87% rename from example/server/internal/oidc.go rename to example/server/storage/oidc.go index 5edf970..91afd90 100644 --- a/example/server/internal/oidc.go +++ b/example/server/storage/oidc.go @@ -1,4 +1,4 @@ -package internal +package storage import ( "time" @@ -11,11 +11,11 @@ import ( ) const ( - //CustomScope is an example for how to use custom scopes in this library + // CustomScope is an example for how to use custom scopes in this library //(in this scenario, when requested, it will return a custom claim) CustomScope = "custom_scope" - //CustomClaim is an example for how to return custom claims with this library + // CustomClaim is an example for how to return custom claims with this library CustomClaim = "custom_claim" ) @@ -44,11 +44,11 @@ func (a *AuthRequest) GetID() string { } func (a *AuthRequest) GetACR() string { - return "" //we won't handle acr in this example + return "" // we won't handle acr in this example } func (a *AuthRequest) GetAMR() []string { - //this example only uses password for authentication + // this example only uses password for authentication if a.passwordChecked { return []string{"pwd"} } @@ -56,7 +56,7 @@ func (a *AuthRequest) GetAMR() []string { } func (a *AuthRequest) GetAudience() []string { - return []string{a.ApplicationID} //this example will always just use the client_id as audience + return []string{a.ApplicationID} // this example will always just use the client_id as audience } func (a *AuthRequest) GetAuthTime() time.Time { @@ -84,7 +84,7 @@ func (a *AuthRequest) GetResponseType() oidc.ResponseType { } func (a *AuthRequest) GetResponseMode() oidc.ResponseMode { - return "" //we won't handle response mode in this example + return "" // we won't handle response mode in this example } func (a *AuthRequest) GetScopes() []string { @@ -100,7 +100,7 @@ func (a *AuthRequest) GetSubject() string { } func (a *AuthRequest) Done() bool { - return a.passwordChecked //this example only uses password for authentication + return a.passwordChecked // this example only uses password for authentication } func PromptToInternal(oidcPrompt oidc.SpaceDelimitedArray) []string { @@ -165,7 +165,7 @@ func CodeChallengeToOIDC(challenge *OIDCCodeChallenge) *oidc.CodeChallenge { } } -//RefreshTokenRequestFromBusiness will simply wrap the internal RefreshToken to implement the op.RefreshTokenRequest interface +// RefreshTokenRequestFromBusiness will simply wrap the storage RefreshToken to implement the op.RefreshTokenRequest interface func RefreshTokenRequestFromBusiness(token *RefreshToken) op.RefreshTokenRequest { return &RefreshTokenRequest{token} } diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go new file mode 100644 index 0000000..7b9d413 --- /dev/null +++ b/example/server/storage/storage.go @@ -0,0 +1,590 @@ +package storage + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "fmt" + "math/big" + "sync" + "time" + + "github.com/google/uuid" + "gopkg.in/square/go-jose.v2" + + "github.com/zitadel/oidc/pkg/oidc" + "github.com/zitadel/oidc/pkg/op" +) + +// serviceKey1 is a public key which will be used for the JWT Profile Authorization Grant +// the corresponding private key is in the service-key1.json (for demonstration purposes) +var serviceKey1 = &rsa.PublicKey{ + N: func() *big.Int { + n, _ := new(big.Int).SetString("00f6d44fb5f34ac2033a75e73cb65ff24e6181edc58845e75a560ac21378284977bb055b1a75b714874e2a2641806205681c09abec76efd52cf40984edcf4c8ca09717355d11ac338f280d3e4c905b00543bdb8ee5a417496cb50cb0e29afc5a0d0471fd5a2fa625bd5281f61e6b02067d4fe7a5349eeae6d6a4300bcd86eef331", 16) + return n + }(), + E: 65537, +} + +// var _ op.Storage = &storage{} +// var _ op.ClientCredentialsStorage = &storage{} + +// storage implements the op.Storage interface +// typically you would implement this as a layer on top of your database +// for simplicity this example keeps everything in-memory +type Storage struct { + lock sync.Mutex + authRequests map[string]*AuthRequest + codes map[string]string + tokens map[string]*Token + clients map[string]*Client + userStore UserStore + services map[string]Service + refreshTokens map[string]*RefreshToken + signingKey signingKey +} + +type signingKey struct { + ID string + Algorithm string + Key *rsa.PrivateKey +} + +func NewStorage(userStore UserStore) *Storage { + key, _ := rsa.GenerateKey(rand.Reader, 2048) + return &Storage{ + authRequests: make(map[string]*AuthRequest), + codes: make(map[string]string), + tokens: make(map[string]*Token), + refreshTokens: make(map[string]*RefreshToken), + clients: clients, + userStore: userStore, + services: map[string]Service{ + userStore.ExampleClientID(): { + keys: map[string]*rsa.PublicKey{ + "key1": serviceKey1, + }, + }, + }, + signingKey: signingKey{ + ID: "id", + Algorithm: "RS256", + Key: key, + }, + } +} + +// CheckUsernamePassword implements the `authenticate` interface of the login +func (s *Storage) CheckUsernamePassword(username, password, id string) error { + s.lock.Lock() + defer s.lock.Unlock() + request, ok := s.authRequests[id] + if !ok { + return fmt.Errorf("request not found") + } + + // for demonstration purposes we'll check we'll have a simple user store and + // a plain text password. For real world scenarios, be sure to have the password + // hashed and salted (e.g. using bcrypt) + user := s.userStore.GetUserByUsername(username) + if user != nil && user.Password == password { + // be sure to set user id into the auth request after the user was checked, + // so that you'll be able to get more information about the user after the login + request.UserID = user.ID + + // you will have to change some state on the request to guide the user through possible multiple steps of the login process + // in this example we'll simply check the username / password and set a boolean to true + // therefore we will also just check this boolean if the request / login has been finished + request.passwordChecked = true + return nil + } + return fmt.Errorf("username or password wrong") +} + +// CreateAuthRequest implements the op.Storage interface +// it will be called after parsing and validation of the authentication request +func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) { + s.lock.Lock() + defer s.lock.Unlock() + + // typically, you'll fill your storage / storage model with the information of the passed object + request := authRequestToInternal(authReq, userID) + + // you'll also have to create a unique id for the request (this might be done by your database; we'll use a uuid) + request.ID = uuid.NewString() + + // and save it in your database (for demonstration purposed we will use a simple map) + s.authRequests[request.ID] = request + + // finally, return the request (which implements the AuthRequest interface of the OP + return request, nil +} + +// AuthRequestByID implements the op.Storage interface +// it will be called after the Login UI redirects back to the OIDC endpoint +func (s *Storage) AuthRequestByID(ctx context.Context, id string) (op.AuthRequest, error) { + s.lock.Lock() + defer s.lock.Unlock() + request, ok := s.authRequests[id] + if !ok { + return nil, fmt.Errorf("request not found") + } + return request, nil +} + +// AuthRequestByCode implements the op.Storage interface +// it will be called after parsing and validation of the token request (in an authorization code flow) +func (s *Storage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) { + // for this example we read the id by code and then get the request by id + requestID, ok := func() (string, bool) { + s.lock.Lock() + defer s.lock.Unlock() + requestID, ok := s.codes[code] + return requestID, ok + }() + if !ok { + return nil, fmt.Errorf("code invalid or expired") + } + return s.AuthRequestByID(ctx, requestID) +} + +// SaveAuthCode implements the op.Storage interface +// it will be called after the authentication has been successful and before redirecting the user agent to the redirect_uri +//(in an authorization code flow) +func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) error { + // for this example we'll just save the authRequestID to the code + s.lock.Lock() + defer s.lock.Unlock() + s.codes[code] = id + return nil +} + +// DeleteAuthRequest implements the op.Storage interface +// it will be called after creating the token response (id and access tokens) for a valid +//- authentication request (in an implicit flow) +//- token request (in an authorization code flow) +func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error { + // you can simply delete all reference to the auth request + s.lock.Lock() + defer s.lock.Unlock() + delete(s.authRequests, id) + for code, requestID := range s.codes { + if id == requestID { + delete(s.codes, code) + return nil + } + } + return nil +} + +// CreateAccessToken implements the op.Storage interface +// it will be called for all requests able to return an access token (Authorization Code Flow, Implicit Flow, JWT Profile, ...) +func (s *Storage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (string, time.Time, error) { + var applicationID string + // if authenticated for an app (auth code / implicit flow) we must save the client_id to the token + authReq, ok := request.(*AuthRequest) + if ok { + applicationID = authReq.ApplicationID + } + token, err := s.accessToken(applicationID, "", request.GetSubject(), request.GetAudience(), request.GetScopes()) + if err != nil { + return "", time.Time{}, err + } + return token.ID, token.Expiration, nil +} + +// CreateAccessAndRefreshTokens implements the op.Storage interface +// it will be called for all requests able to return an access and refresh token (Authorization Code Flow, Refresh Token Request) +func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) { + // get the information depending on the request type / implementation + applicationID, authTime, amr := getInfoFromRequest(request) + + // if currentRefreshToken is empty (Code Flow) we will have to create a new refresh token + if currentRefreshToken == "" { + refreshTokenID := uuid.NewString() + accessToken, err := s.accessToken(applicationID, refreshTokenID, request.GetSubject(), request.GetAudience(), request.GetScopes()) + if err != nil { + return "", "", time.Time{}, err + } + refreshToken, err := s.createRefreshToken(accessToken, amr, authTime) + if err != nil { + return "", "", time.Time{}, err + } + return accessToken.ID, refreshToken, accessToken.Expiration, nil + } + + // if we get here, the currentRefreshToken was not empty, so the call is a refresh token request + // we therefore will have to check the currentRefreshToken and renew the refresh token + refreshToken, refreshTokenID, err := s.renewRefreshToken(currentRefreshToken) + if err != nil { + return "", "", time.Time{}, err + } + accessToken, err := s.accessToken(applicationID, refreshTokenID, request.GetSubject(), request.GetAudience(), request.GetScopes()) + if err != nil { + return "", "", time.Time{}, err + } + return accessToken.ID, refreshToken, accessToken.Expiration, nil +} + +// TokenRequestByRefreshToken implements the op.Storage interface +// it will be called after parsing and validation of the refresh token request +func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) { + s.lock.Lock() + defer s.lock.Unlock() + token, ok := s.refreshTokens[refreshToken] + if !ok { + return nil, fmt.Errorf("invalid refresh_token") + } + return RefreshTokenRequestFromBusiness(token), nil +} + +// TerminateSession implements the op.Storage interface +// it will be called after the user signed out, therefore the access and refresh token of the user of this client must be removed +func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID string) error { + s.lock.Lock() + defer s.lock.Unlock() + for _, token := range s.tokens { + if token.ApplicationID == clientID && token.Subject == userID { + delete(s.tokens, token.ID) + delete(s.refreshTokens, token.RefreshTokenID) + return nil + } + } + return nil +} + +// RevokeToken implements the op.Storage interface +// it will be called after parsing and validation of the token revocation request +func (s *Storage) RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error { + // a single token was requested to be removed + s.lock.Lock() + defer s.lock.Unlock() + accessToken, ok := s.tokens[token] + if ok { + if accessToken.ApplicationID != clientID { + return oidc.ErrInvalidClient().WithDescription("token was not issued for this client") + } + // if it is an access token, just remove it + // you could also remove the corresponding refresh token if really necessary + delete(s.tokens, accessToken.ID) + return nil + } + refreshToken, ok := s.refreshTokens[token] + if !ok { + // if the token is neither an access nor a refresh token, just ignore it, the expected behaviour of + // being not valid (anymore) is achieved + return nil + } + if refreshToken.ApplicationID != clientID { + return oidc.ErrInvalidClient().WithDescription("token was not issued for this client") + } + // if it is a refresh token, you will have to remove the access token as well + delete(s.refreshTokens, refreshToken.ID) + for _, accessToken := range s.tokens { + if accessToken.RefreshTokenID == refreshToken.ID { + delete(s.tokens, accessToken.ID) + return nil + } + } + return nil +} + +// GetSigningKey implements the op.Storage interface +// it will be called when creating the OpenID Provider +func (s *Storage) GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey) { + // in this example the signing key is a static rsa.PrivateKey and the algorithm used is RS256 + // you would obviously have a more complex implementation and store / retrieve the key from your database as well + // + // the idea of the signing key channel is, that you can (with what ever mechanism) rotate your signing key and + // switch the key of the signer via this channel + keyCh <- jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(s.signingKey.Algorithm), // always tell the signer with algorithm to use + Key: jose.JSONWebKey{ + KeyID: s.signingKey.ID, // always give the key an id so, that it will include it in the token header as `kid` claim + Key: s.signingKey.Key, + }, + } +} + +// GetKeySet implements the op.Storage interface +// it will be called to get the current (public) keys, among others for the keys_endpoint or for validating access_tokens on the userinfo_endpoint, ... +func (s *Storage) GetKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) { + // as mentioned above, this example only has a single signing key without key rotation, + // so it will directly use its public key + // + // when using key rotation you typically would store the public keys alongside the private keys in your database + // and give both of them an expiration date, with the public key having a longer lifetime (e.g. rotate private key every + return &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: s.signingKey.ID, + Algorithm: s.signingKey.Algorithm, + Use: oidc.KeyUseSignature, + Key: &s.signingKey.Key.PublicKey, + }, + }, + }, nil +} + +// GetClientByClientID implements the op.Storage interface +// it will be called whenever information (type, redirect_uris, ...) about the client behind the client_id is needed +func (s *Storage) GetClientByClientID(ctx context.Context, clientID string) (op.Client, error) { + s.lock.Lock() + defer s.lock.Unlock() + client, ok := s.clients[clientID] + if !ok { + return nil, fmt.Errorf("client not found") + } + return client, nil +} + +// AuthorizeClientIDSecret implements the op.Storage interface +// it will be called for validating the client_id, client_secret on token or introspection requests +func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error { + s.lock.Lock() + defer s.lock.Unlock() + client, ok := s.clients[clientID] + if !ok { + return fmt.Errorf("client not found") + } + // for this example we directly check the secret + // obviously you would not have the secret in plain text, but rather hashed and salted (e.g. using bcrypt) + if client.secret != clientSecret { + return fmt.Errorf("invalid secret") + } + return nil +} + +// SetUserinfoFromScopes implements the op.Storage interface +// it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check +func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, userID, clientID string, scopes []string) error { + return s.setUserinfo(ctx, userinfo, userID, clientID, scopes) +} + +// SetUserinfoFromToken implements the op.Storage interface +// it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function +func (s *Storage) SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, tokenID, subject, origin string) error { + token, ok := func() (*Token, bool) { + s.lock.Lock() + defer s.lock.Unlock() + token, ok := s.tokens[tokenID] + return token, ok + }() + if !ok { + return fmt.Errorf("token is invalid or has expired") + } + // the userinfo endpoint should support CORS. If it's not possible to specify a specific origin in the CORS handler, + // and you have to specify a wildcard (*) origin, then you could also check here if the origin which called the userinfo endpoint here directly + // note that the origin can be empty (if called by a web client) + // + // if origin != "" { + // client, ok := s.clients[token.ApplicationID] + // if !ok { + // return fmt.Errorf("client not found") + // } + // if err := checkAllowedOrigins(client.allowedOrigins, origin); err != nil { + // return err + // } + //} + return s.setUserinfo(ctx, userinfo, token.Subject, token.ApplicationID, token.Scopes) +} + +// SetIntrospectionFromToken implements the op.Storage interface +// it will be called for the introspection endpoint, so we read the token and pass the information from that to the private function +func (s *Storage) SetIntrospectionFromToken(ctx context.Context, introspection oidc.IntrospectionResponse, tokenID, subject, clientID string) error { + token, ok := func() (*Token, bool) { + s.lock.Lock() + defer s.lock.Unlock() + token, ok := s.tokens[tokenID] + return token, ok + }() + if !ok { + return fmt.Errorf("token is invalid or has expired") + } + // check if the client is part of the requested audience + for _, aud := range token.Audience { + if aud == clientID { + // the introspection response only has to return a boolean (active) if the token is active + // this will automatically be done by the library if you don't return an error + // you can also return further information about the user / associated token + // e.g. the userinfo (equivalent to userinfo endpoint) + err := s.setUserinfo(ctx, introspection, subject, clientID, token.Scopes) + if err != nil { + return err + } + //...and also the requested scopes... + introspection.SetScopes(token.Scopes) + //...and the client the token was issued to + introspection.SetClientID(token.ApplicationID) + return nil + } + } + return fmt.Errorf("token is not valid for this client") +} + +// GetPrivateClaimsFromScopes implements the op.Storage interface +// it will be called for the creation of a JWT access token to assert claims for custom scopes +func (s *Storage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) { + for _, scope := range scopes { + switch scope { + case CustomScope: + claims = appendClaim(claims, CustomClaim, customClaim(clientID)) + } + } + return claims, nil +} + +// GetKeyByIDAndUserID implements the op.Storage interface +// it will be called to validate the signatures of a JWT (JWT Profile Grant and Authentication) +func (s *Storage) GetKeyByIDAndUserID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error) { + s.lock.Lock() + defer s.lock.Unlock() + service, ok := s.services[clientID] + if !ok { + return nil, fmt.Errorf("clientID not found") + } + key, ok := service.keys[keyID] + if !ok { + return nil, fmt.Errorf("key not found") + } + return &jose.JSONWebKey{ + KeyID: keyID, + Use: "sig", + Key: key, + }, nil +} + +// ValidateJWTProfileScopes implements the op.Storage interface +// it will be called to validate the scopes of a JWT Profile Authorization Grant request +func (s *Storage) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) { + allowedScopes := make([]string, 0) + for _, scope := range scopes { + if scope == oidc.ScopeOpenID { + allowedScopes = append(allowedScopes, scope) + } + } + return allowedScopes, nil +} + +// Health implements the op.Storage interface +func (s *Storage) Health(ctx context.Context) error { + return nil +} + +// createRefreshToken will store a refresh_token in-memory based on the provided information +func (s *Storage) createRefreshToken(accessToken *Token, amr []string, authTime time.Time) (string, error) { + s.lock.Lock() + defer s.lock.Unlock() + token := &RefreshToken{ + ID: accessToken.RefreshTokenID, + Token: accessToken.RefreshTokenID, + AuthTime: authTime, + AMR: amr, + ApplicationID: accessToken.ApplicationID, + UserID: accessToken.Subject, + Audience: accessToken.Audience, + Expiration: time.Now().Add(5 * time.Hour), + Scopes: accessToken.Scopes, + } + s.refreshTokens[token.ID] = token + return token.Token, nil +} + +// renewRefreshToken checks the provided refresh_token and creates a new one based on the current +func (s *Storage) renewRefreshToken(currentRefreshToken string) (string, string, error) { + s.lock.Lock() + defer s.lock.Unlock() + refreshToken, ok := s.refreshTokens[currentRefreshToken] + if !ok { + return "", "", fmt.Errorf("invalid refresh token") + } + // deletes the refresh token and all access tokens which were issued based on this refresh token + delete(s.refreshTokens, currentRefreshToken) + for _, token := range s.tokens { + if token.RefreshTokenID == currentRefreshToken { + delete(s.tokens, token.ID) + break + } + } + // creates a new refresh token based on the current one + token := uuid.NewString() + refreshToken.Token = token + s.refreshTokens[token] = refreshToken + return token, refreshToken.ID, nil +} + +// accessToken will store an access_token in-memory based on the provided information +func (s *Storage) accessToken(applicationID, refreshTokenID, subject string, audience, scopes []string) (*Token, error) { + s.lock.Lock() + defer s.lock.Unlock() + token := &Token{ + ID: uuid.NewString(), + ApplicationID: applicationID, + RefreshTokenID: refreshTokenID, + Subject: subject, + Audience: audience, + Expiration: time.Now().Add(5 * time.Minute), + Scopes: scopes, + } + s.tokens[token.ID] = token + return token, nil +} + +// setUserinfo sets the info based on the user, scopes and if necessary the clientID +func (s *Storage) setUserinfo(ctx context.Context, userInfo oidc.UserInfoSetter, userID, clientID string, scopes []string) (err error) { + s.lock.Lock() + defer s.lock.Unlock() + user := s.userStore.GetUserByID(userID) + if user == nil { + return fmt.Errorf("user not found") + } + for _, scope := range scopes { + switch scope { + case oidc.ScopeOpenID: + userInfo.SetSubject(user.ID) + case oidc.ScopeEmail: + userInfo.SetEmail(user.Email, user.EmailVerified) + case oidc.ScopeProfile: + userInfo.SetPreferredUsername(user.Username) + userInfo.SetName(user.FirstName + " " + user.LastName) + userInfo.SetFamilyName(user.LastName) + userInfo.SetGivenName(user.FirstName) + userInfo.SetLocale(user.PreferredLanguage) + case oidc.ScopePhone: + userInfo.SetPhone(user.Phone, user.PhoneVerified) + case CustomScope: + // you can also have a custom scope and assert public or custom claims based on that + userInfo.AppendClaims(CustomClaim, customClaim(clientID)) + } + } + return nil +} + +// getInfoFromRequest returns the clientID, authTime and amr depending on the op.TokenRequest type / implementation +func getInfoFromRequest(req op.TokenRequest) (clientID string, authTime time.Time, amr []string) { + authReq, ok := req.(*AuthRequest) // Code Flow (with scope offline_access) + if ok { + return authReq.ApplicationID, authReq.authTime, authReq.GetAMR() + } + refreshReq, ok := req.(*RefreshTokenRequest) // Refresh Token Request + if ok { + return refreshReq.ApplicationID, refreshReq.AuthTime, refreshReq.AMR + } + return "", time.Time{}, nil +} + +// customClaim demonstrates how to return custom claims based on provided information +func customClaim(clientID string) map[string]interface{} { + return map[string]interface{}{ + "client": clientID, + "other": "stuff", + } +} + +func appendClaim(claims map[string]interface{}, claim string, value interface{}) map[string]interface{} { + if claims == nil { + claims = make(map[string]interface{}) + } + claims[claim] = value + return claims +} diff --git a/example/server/internal/token.go b/example/server/storage/token.go similarity index 96% rename from example/server/internal/token.go rename to example/server/storage/token.go index 09e675a..ad907e3 100644 --- a/example/server/internal/token.go +++ b/example/server/storage/token.go @@ -1,4 +1,4 @@ -package internal +package storage import "time" diff --git a/example/server/storage/user.go b/example/server/storage/user.go new file mode 100644 index 0000000..423af59 --- /dev/null +++ b/example/server/storage/user.go @@ -0,0 +1,71 @@ +package storage + +import ( + "crypto/rsa" + + "golang.org/x/text/language" +) + +type User struct { + ID string + Username string + Password string + FirstName string + LastName string + Email string + EmailVerified bool + Phone string + PhoneVerified bool + PreferredLanguage language.Tag +} + +type Service struct { + keys map[string]*rsa.PublicKey +} + +type UserStore interface { + GetUserByID(string) *User + GetUserByUsername(string) *User + ExampleClientID() string +} + +type userStore struct { + users map[string]*User +} + +func NewUserStore() UserStore { + return userStore{ + users: map[string]*User{ + "id1": { + ID: "id1", + Username: "test-user", + Password: "verysecure", + FirstName: "Test", + LastName: "User", + Email: "test-user@zitadel.ch", + EmailVerified: true, + Phone: "", + PhoneVerified: false, + PreferredLanguage: language.German, + }, + }, + } +} + +// ExampleClientID is only used in the example server +func (u userStore) ExampleClientID() string { + return "service" +} + +func (u userStore) GetUserByID(id string) *User { + return u.users[id] +} + +func (u userStore) GetUserByUsername(username string) *User { + for _, user := range u.users { + if user.Username == username { + return user + } + } + return nil +}