diff --git a/example/client/app/app.go b/example/client/app/app.go index 9a6dd97..4c0831b 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -10,7 +10,6 @@ import ( "github.com/google/uuid" "github.com/sirupsen/logrus" - "golang.org/x/oauth2" "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/rp" @@ -30,17 +29,13 @@ func main() { ctx := context.Background() - rpConfig := &rp.Configuration{ - Issuer: issuer, - Config: &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - RedirectURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath), - Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail}, - }, - } + redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) + scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail} cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) - provider, err := rp.NewRelayingParty(rpConfig, rp.WithCookieHandler(cookieHandler), rp.WithPKCE(cookieHandler), rp.WithVerifierOpts(rp.WithIssuedAtOffset(-3*time.Minute))) //, + provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, + rp.WithPKCE(cookieHandler), + rp.WithVerifierOpts(rp.WithIssuedAtOffset(5*time.Second)), + ) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } diff --git a/example/client/github/github.go b/example/client/github/github.go index 56d4e8d..5489389 100644 --- a/example/client/github/github.go +++ b/example/client/github/github.go @@ -5,10 +5,14 @@ import ( "fmt" "os" - "github.com/caos/oidc/pkg/cli" - "github.com/caos/oidc/pkg/rp" "github.com/google/go-github/v31/github" + "github.com/google/uuid" + "golang.org/x/oauth2" githubOAuth "golang.org/x/oauth2/github" + + "github.com/caos/oidc/pkg/rp" + "github.com/caos/oidc/pkg/rp/cli" + "github.com/caos/oidc/pkg/utils" ) var ( @@ -21,23 +25,32 @@ func main() { clientSecret := os.Getenv("CLIENT_SECRET") port := os.Getenv("PORT") - rpConfig := &rp.Config{ + rpConfig := &oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, - CallbackURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath), + RedirectURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath), Scopes: []string{"repo", "repo_deployment"}, - Endpoints: githubOAuth.Endpoint, + Endpoint: githubOAuth.Endpoint, } - oauth2Client := cli.CodeFlowForClient(rpConfig, key, callbackPath, port) - - client := github.NewClient(oauth2Client) - ctx := context.Background() - _, _, err := client.Users.Get(ctx, "") + cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) + relayingParty, err := rp.NewRelayingPartyOAuth(rpConfig, rp.WithCookieHandler(cookieHandler)) if err != nil { - fmt.Println("OAuth flow failed") - } else { - fmt.Println("OAuth flow success") + fmt.Printf("error creating relaying party: %v", err) + return } + state := func() string { + return uuid.New().String() + } + token := cli.CodeFlow(relayingParty, callbackPath, port, state) + + client := github.NewClient(relayingParty.Client(ctx, token.Token)) + + _, _, err = client.Users.Get(ctx, "") + if err != nil { + fmt.Printf("error %v", err) + return + } + fmt.Println("call succeeded") } diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go deleted file mode 100644 index 4d642ba..0000000 --- a/pkg/cli/cli.go +++ /dev/null @@ -1,121 +0,0 @@ -package cli - -import ( - "context" - "fmt" - "log" - "net/http" - "strings" - "time" - - "github.com/google/uuid" - "github.com/sirupsen/logrus" - - "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/rp" - "github.com/caos/oidc/pkg/utils" -) - -func CodeFlow(rpc *rp.Configuration, key []byte, callbackPath string, port string) *oidc.Tokens { - cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) - provider, err := rp.NewRelayingParty(rpc, rp.WithCookieHandler(cookieHandler)) - if err != nil { - logrus.Fatalf("error creating provider %s", err.Error()) - } - - return codeFlow(provider, callbackPath, port) -} - -func TokenForClient(rpc *rp.Configuration, key []byte, token *oidc.Tokens) *http.Client { - cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) - provider, err := rp.NewRelayingParty(rpc, rp.WithCookieHandler(cookieHandler)) - if err != nil { - logrus.Fatalf("error creating provider %s", err.Error()) - } - - return provider.Client(context.Background(), token.Token) -} - -func CodeFlowForClient(rpc *rp.Configuration, key []byte, callbackPath string, port string) *http.Client { - cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) - provider, err := rp.NewRelayingParty(rpc, rp.WithCookieHandler(cookieHandler)) - if err != nil { - logrus.Fatalf("error creating provider %s", err.Error()) - } - token := codeFlow(provider, callbackPath, port) - - return provider.Client(context.Background(), token.Token) -} - -func codeFlow(provider rp.RelayingParty, callbackPath string, port string) *oidc.Tokens { - loginPath := "/login" - portStr := port - if !strings.HasPrefix(port, ":") { - portStr = strings.Join([]string{":", portStr}, "") - } - - getToken, setToken := getAndSetTokens() - - state := func() string { - return uuid.New().String() - } - http.Handle(loginPath, rp.AuthURLHandler(state, provider)) - - marshal := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) { - setToken(w, tokens) - } - http.Handle(callbackPath, rp.CodeExchangeHandler(marshal, provider)) - - // start http-server - stopHttpServer := startHttpServer(portStr) - - // open browser in different window - utils.OpenBrowser(strings.Join([]string{"http://localhost", portStr, loginPath}, "")) - - // wait until user is logged into browser - ret := getToken() - - // stop http-server as no callback is needed anymore - stopHttpServer() - - // return tokens - return ret -} - -func startHttpServer(port string) func() { - srv := &http.Server{Addr: port} - go func() { - - // always returns error. ErrServerClosed on graceful close - if err := srv.ListenAndServe(); err != http.ErrServerClosed { - // unexpected error. port in use? - log.Fatalf("ListenAndServe(): %v", err) - } - }() - - return func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := srv.Shutdown(ctx); err != nil { - log.Fatalf("Shutdown(): %v", err) - } - } -} - -func getAndSetTokens() (func() *oidc.Tokens, func(w http.ResponseWriter, tokens *oidc.Tokens)) { - marshalChan := make(chan *oidc.Tokens) - - getToken := func() *oidc.Tokens { - return <-marshalChan - } - setToken := func(w http.ResponseWriter, tokens *oidc.Tokens) { - marshalChan <- tokens - - msg := "

Success!

" - msg = msg + "

You are authenticated and can now return to the CLI.

" - fmt.Fprintf(w, msg) - } - - return getToken, setToken -} diff --git a/pkg/rp/cli/cli.go b/pkg/rp/cli/cli.go new file mode 100644 index 0000000..4b00ba0 --- /dev/null +++ b/pkg/rp/cli/cli.go @@ -0,0 +1,35 @@ +package cli + +import ( + "context" + "net/http" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/rp" + "github.com/caos/oidc/pkg/utils" +) + +const ( + loginPath = "/login" +) + +func CodeFlow(relayingParty rp.RelayingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var token *oidc.Tokens + callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) { + token = tokens + msg := "

Success!

" + msg = msg + "

You are authenticated and can now return to the CLI.

" + w.Write([]byte(msg)) + } + http.Handle(loginPath, rp.AuthURLHandler(stateProvider, relayingParty)) + http.Handle(callbackPath, rp.CodeExchangeHandler(callback, relayingParty)) + + utils.StartServer(ctx, port) + + utils.OpenBrowser("http://localhost:" + port + loginPath) + + return token +} diff --git a/pkg/rp/relaying_party.go b/pkg/rp/relaying_party.go index 43d0c97..fc85a4a 100644 --- a/pkg/rp/relaying_party.go +++ b/pkg/rp/relaying_party.go @@ -39,30 +39,31 @@ type RelayingParty interface { ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) } +type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) + var ( - DefaultErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) { + DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) { http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError) } ) type relayingParty struct { - endpoints Endpoints - - config *Configuration - pkce bool + issuer string + endpoints Endpoints + oauthConfig *oauth2.Config + oauth2Only bool + pkce bool httpClient *http.Client cookieHandler *utils.CookieHandler - - errorHandler func(http.ResponseWriter, *http.Request, string, string, string) + errorHandler func(http.ResponseWriter, *http.Request, string, string, string) idTokenVerifier IDTokenVerifier verifierOpts []VerifierOption - oauth2Only bool } func (rp *relayingParty) OAuthConfig() *oauth2.Config { - return rp.config.Config + return rp.oauthConfig } func (rp *relayingParty) IsPKCE() bool { @@ -83,97 +84,69 @@ func (rp *relayingParty) IsOAuth2Only() bool { func (rp *relayingParty) IDTokenVerifier() IDTokenVerifier { if rp.idTokenVerifier == nil { - rp.idTokenVerifier = NewIDTokenVerifier(rp.config.Issuer, rp.config.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...) + rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...) } return rp.idTokenVerifier } func (rp *relayingParty) Client(ctx context.Context, token *oauth2.Token) *http.Client { - return rp.config.Config.Client(ctx, token) + return rp.oauthConfig.Client(ctx, token) } func (rp *relayingParty) ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) { + if rp.errorHandler == nil { + rp.errorHandler = DefaultErrorHandler + } return rp.errorHandler } -//NewRelayingParty creates a DelegationTokenExchangeRP with the given -//Config and possible configOptions -//it will run discovery on the provided issuer if AuthURL and TokenURL are not set -//if no verifier is provided using the options the `DefaultVerifier` is set -func NewRelayingParty(config *Configuration, options ...Option) (RelayingParty, error) { - isOpenID := isOpenID(config.Scopes) - +//NewRelayingPartyOAuth creates an (OAuth2) RelayingParty with the given +//OAuth2 Config and possible configOptions +//it will use the AuthURL and TokenURL set in config +func NewRelayingPartyOAuth(config *oauth2.Config, options ...Option) (RelayingParty, error) { rp := &relayingParty{ - config: config, - httpClient: utils.DefaultHTTPClient, - oauth2Only: !isOpenID, + oauthConfig: config, + httpClient: utils.DefaultHTTPClient, + oauth2Only: true, } for _, optFunc := range options { optFunc(rp) } - if isOpenID && config.Endpoint.AuthURL == "" && config.Endpoint.TokenURL == "" { - endpoints, err := Discover(config.Issuer, rp.httpClient) - if err != nil { - return nil, err - } - rp.config.Endpoint = endpoints.Endpoint - rp.endpoints = endpoints - } - - if rp.errorHandler == nil { - rp.errorHandler = DefaultErrorHandler - } - - if isOpenID && rp.idTokenVerifier == nil { - rp.idTokenVerifier = NewIDTokenVerifier(config.Issuer, config.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL)) - } - return rp, nil } -func NewRelayingParty2(clientID, clientSecret, redirectURI string, options ...Option) (RelayingParty, error) { +//NewRelayingPartyOIDC creates an (OIDC) RelayingParty with the given +//issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions +//it will run discovery on the provided issuer and use the found endpoints +func NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelayingParty, error) { rp := &relayingParty{ - config: &Configuration{ - Config: &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - RedirectURL: redirectURI, - }, + issuer: issuer, + oauthConfig: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURI, + Scopes: scopes, }, httpClient: utils.DefaultHTTPClient, - oauth2Only: true, + oauth2Only: false, } for _, optFunc := range options { optFunc(rp) } - if !rp.oauth2Only && rp.config.Endpoint.AuthURL == "" && rp.config.Endpoint.TokenURL == "" { - endpoints, err := Discover(rp.config.Issuer, rp.httpClient) - if err != nil { - return nil, err - } - rp.config.Endpoint = endpoints.Endpoint - rp.endpoints = endpoints - } - - if rp.errorHandler == nil { - rp.errorHandler = DefaultErrorHandler + endpoints, err := Discover(rp.issuer, rp.httpClient) + if err != nil { + return nil, err } + rp.oauthConfig.Endpoint = endpoints.Endpoint + rp.endpoints = endpoints return rp, nil } -func WithOIDC(issuer string, scopes []string) Option { - return func(rp *relayingParty) { - rp.config.Issuer = issuer - rp.config.Scopes = scopes - rp.oauth2Only = false - } -} - //DefaultRPOpts is the type for providing dynamic options to the DefaultRP type Option func(*relayingParty) @@ -201,6 +174,12 @@ func WithHTTPClient(client *http.Client) Option { } } +func WithErrorHandler(errorHandler ErrorHandler) Option { + return func(rp *relayingParty) { + rp.errorHandler = errorHandler + } +} + func WithVerifierOpts(opts ...VerifierOption) Option { return func(rp *relayingParty) { rp.verifierOpts = opts @@ -414,12 +393,3 @@ func WithCodeVerifier(codeVerifier string) CodeExchangeOpt { return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("code_verifier", codeVerifier)} } } - -func isOpenID(scopes []string) bool { - for _, scope := range scopes { - if scope == oidc.ScopeOpenID { - return true - } - } - return false -} diff --git a/pkg/utils/http.go b/pkg/utils/http.go index 5da12af..993febb 100644 --- a/pkg/utils/http.go +++ b/pkg/utils/http.go @@ -1,9 +1,11 @@ package utils import ( + "context" "encoding/json" "fmt" "io/ioutil" + "log" "net/http" "net/url" "strings" @@ -79,3 +81,18 @@ func URLEncodeResponse(resp interface{}, encoder Encoder) (string, error) { v := url.Values(values) return v.Encode(), nil } + +func StartServer(ctx context.Context, port string) { + server := &http.Server{Addr: port} + go func() { + if err := server.ListenAndServe(); err != http.ErrServerClosed { + log.Fatalf("ListenAndServe(): %v", err) + } + }() + + go func() { + <-ctx.Done() + err := server.Shutdown(ctx) + log.Fatalf("Shutdown(): %v", err) + }() +}