From 2cbf96e448ac2ae6a67ae5a68648d6b925afe5d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Thu, 1 Aug 2024 13:34:12 +0300 Subject: [PATCH] feat(op): allow returning of parent errors to client --- pkg/oidc/error.go | 24 ++++++++++++++++++++++++ pkg/oidc/error_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index 2f0572d..c609f8e 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -1,6 +1,7 @@ package oidc import ( + "encoding/json" "errors" "fmt" "log/slog" @@ -133,6 +134,24 @@ type Error struct { Description string `json:"error_description,omitempty" schema:"error_description,omitempty"` State string `json:"state,omitempty" schema:"state,omitempty"` redirectDisabled bool `schema:"-"` + returnParent bool `schema:"-"` +} + +func (e *Error) MarshalJSON() ([]byte, error) { + m := struct { + Error errorType `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + State string `json:"state,omitempty"` + Parent string `json:"parent,omitempty"` + }{ + Error: e.ErrorType, + ErrorDescription: e.Description, + State: e.State, + } + if e.returnParent { + m.Parent = e.Parent.Error() + } + return json.Marshal(m) } func (e *Error) Error() string { @@ -165,6 +184,11 @@ func (e *Error) WithParent(err error) *Error { return e } +func (e *Error) WithReturnParentToClient(b bool) *Error { + e.returnParent = b + return e +} + func (e *Error) WithDescription(desc string, args ...any) *Error { e.Description = fmt.Sprintf(desc, args...) return e diff --git a/pkg/oidc/error_test.go b/pkg/oidc/error_test.go index 2eeb4e6..40d30b1 100644 --- a/pkg/oidc/error_test.go +++ b/pkg/oidc/error_test.go @@ -1,11 +1,14 @@ package oidc import ( + "encoding/json" + "errors" "io" "log/slog" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDefaultToServerError(t *testing.T) { @@ -151,3 +154,39 @@ func TestError_LogValue(t *testing.T) { }) } } + +func TestError_MarshalJSON(t *testing.T) { + tests := []struct { + name string + e *Error + want string + }{ + { + name: "simple error", + e: ErrAccessDenied(), + want: `{"error":"access_denied","error_description":"The authorization request was denied."}`, + }, + { + name: "with description", + e: ErrAccessDenied().WithDescription("oops"), + want: `{"error":"access_denied","error_description":"oops"}`, + }, + { + name: "with parent", + e: ErrServerError().WithParent(errors.New("oops")), + want: `{"error":"server_error"}`, + }, + { + name: "with return parent", + e: ErrServerError().WithParent(errors.New("oops")).WithReturnParentToClient(true), + want: `{"error":"server_error","parent":"oops"}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.e) + require.NoError(t, err) + assert.JSONEq(t, tt.want, string(got)) + }) + } +}