Skip to content

Commit cd28916

Browse files
hrntknrCopilot
andauthored
fix: improve authentication flow and session handling (#45)
* fix: improve authentication flow and session handling - Add authorization check after user authentication - Fix redirect URL handling in session management - Refactor template rendering into dedicated methods - Improve logout behavior to redirect to login page - Clean up session management by deleting specific keys instead of clearing all * Update pkg/auth/auth.go Co-authored-by: Copilot <[email protected]> * refactor: improve test organization and fix authorization flow handling --------- Co-authored-by: Copilot <[email protected]>
1 parent 9038812 commit cd28916

2 files changed

Lines changed: 90 additions & 71 deletions

File tree

pkg/auth/auth.go

Lines changed: 75 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ var templateFS embed.FS
1818
type AuthRouter struct {
1919
passwordHash []string
2020
providers []Provider
21-
template *template.Template
21+
loginTemplate *template.Template
2222
unauthorizedTemplate *template.Template
2323
}
2424

@@ -36,7 +36,7 @@ func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, e
3636
return &AuthRouter{
3737
passwordHash: passwordHash,
3838
providers: providers,
39-
template: tmpl,
39+
loginTemplate: tmpl,
4040
unauthorizedTemplate: unauthorizedTmpl,
4141
}, nil
4242
}
@@ -82,11 +82,28 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
8282
c.Error(err)
8383
return
8484
}
85+
ok, err := provider.Authorization(userID)
86+
if err != nil {
87+
c.Error(err)
88+
return
89+
}
90+
if !ok {
91+
a.renderUnauthorized(c, userID, provider.Name())
92+
return
93+
}
8594
session.Set(SessionKeyProvider, provider.Name())
8695
session.Set(SessionKeyUserID, userID)
87-
session.Save()
8896
redirectURL := session.Get(SessionKeyRedirectURL)
89-
c.Redirect(http.StatusFound, redirectURL.(string))
97+
if redirectURL != nil {
98+
session.Delete(SessionKeyRedirectURL)
99+
}
100+
session.Save()
101+
102+
if redirectURL == nil {
103+
c.Redirect(http.StatusFound, "/")
104+
} else {
105+
c.Redirect(http.StatusFound, redirectURL.(string))
106+
}
90107
})
91108

92109
router.GET(provider.AuthURL(), func(c *gin.Context) {
@@ -118,29 +135,12 @@ func (a *AuthRouter) getProvider(name string) Provider {
118135
return nil
119136
}
120137

121-
type templateData struct {
122-
Providers []Provider
123-
HasPassword bool
124-
PasswordError string
125-
}
126-
127138
func (a *AuthRouter) handleLogin(c *gin.Context) {
128139
if c.Request.Method == "POST" {
129140
a.handleLoginPost(c)
130141
return
131142
}
132-
133-
data := templateData{
134-
Providers: a.providers,
135-
HasPassword: len(a.passwordHash) > 0,
136-
PasswordError: "",
137-
}
138-
139-
c.Header("Content-Type", "text/html; charset=utf-8")
140-
if err := a.template.Execute(c.Writer, data); err != nil {
141-
c.AbortWithError(http.StatusInternalServerError, err)
142-
return
143-
}
143+
a.renderLogin(c, "")
144144
}
145145

146146
func (a *AuthRouter) handleLoginPost(c *gin.Context) {
@@ -165,39 +165,32 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
165165
}
166166

167167
if errorMessage != "" {
168-
data := templateData{
169-
Providers: a.providers,
170-
HasPassword: len(a.passwordHash) > 0,
171-
PasswordError: errorMessage,
172-
}
173-
174-
c.Header("Content-Type", "text/html; charset=utf-8")
175-
c.Status(http.StatusBadRequest)
176-
if err := a.template.Execute(c.Writer, data); err != nil {
177-
c.AbortWithError(http.StatusInternalServerError, err)
178-
return
179-
}
168+
a.renderLogin(c, errorMessage)
180169
return
181170
}
182171

183172
session := sessions.Default(c)
184173
session.Set(SessionKeyProvider, PasswordProvider)
185174
session.Set(SessionKeyUserID, PasswordUserID)
175+
redirectURL := session.Get(SessionKeyRedirectURL)
176+
if redirectURL != nil {
177+
session.Delete(SessionKeyRedirectURL)
178+
}
186179
session.Save()
187180

188-
redirectURL := session.Get(SessionKeyRedirectURL)
189181
if redirectURL == nil {
190182
c.Redirect(http.StatusFound, "/")
191-
return
183+
} else {
184+
c.Redirect(http.StatusFound, redirectURL.(string))
192185
}
193-
c.Redirect(http.StatusFound, redirectURL.(string))
194186
}
195187

196188
func (a *AuthRouter) handleLogout(c *gin.Context) {
197189
session := sessions.Default(c)
198-
session.Clear()
190+
session.Delete(SessionKeyProvider)
191+
session.Delete(SessionKeyUserID)
199192
session.Save()
200-
c.String(http.StatusOK, "Logged out")
193+
c.Redirect(http.StatusFound, LoginEndpoint)
201194
}
202195

203196
func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
@@ -229,22 +222,52 @@ func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
229222
return
230223
}
231224
if !ok {
232-
data := struct {
233-
UserID string
234-
Provider string
235-
}{
236-
UserID: userID.(string),
237-
Provider: providerName.(string),
238-
}
239-
c.Header("Content-Type", "text/html; charset=utf-8")
240-
c.Status(http.StatusForbidden)
241-
if err := a.unauthorizedTemplate.Execute(c.Writer, data); err != nil {
242-
c.AbortWithError(http.StatusInternalServerError, err)
243-
return
244-
}
225+
a.renderUnauthorized(c, userID.(string), providerName.(string))
245226
c.Abort()
246227
return
247228
}
248229
c.Next()
249230
}
250231
}
232+
233+
type loginTemplateData struct {
234+
Providers []Provider
235+
HasPassword bool
236+
PasswordError string
237+
}
238+
239+
type unauthorizedTemplateData struct {
240+
UserID string
241+
Provider string
242+
}
243+
244+
func (a *AuthRouter) renderLogin(c *gin.Context, passwordError string) {
245+
data := loginTemplateData{
246+
Providers: a.providers,
247+
HasPassword: len(a.passwordHash) > 0,
248+
PasswordError: passwordError,
249+
}
250+
c.Header("Content-Type", "text/html; charset=utf-8")
251+
if passwordError != "" {
252+
c.Status(http.StatusBadRequest)
253+
} else {
254+
c.Status(http.StatusOK)
255+
}
256+
if err := a.loginTemplate.Execute(c.Writer, data); err != nil {
257+
c.AbortWithError(http.StatusInternalServerError, err)
258+
return
259+
}
260+
}
261+
262+
func (a *AuthRouter) renderUnauthorized(c *gin.Context, userID, providerName string) {
263+
data := unauthorizedTemplateData{
264+
UserID: userID,
265+
Provider: providerName,
266+
}
267+
c.Header("Content-Type", "text/html; charset=utf-8")
268+
c.Status(http.StatusForbidden)
269+
if err := a.unauthorizedTemplate.Execute(c.Writer, data); err != nil {
270+
c.AbortWithError(http.StatusInternalServerError, err)
271+
return
272+
}
273+
}

pkg/auth/auth_test.go

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,15 @@ func TestAuthenticationFlow(t *testing.T) {
7979
defer ctrl.Finish()
8080

8181
// Create mock provider
82+
mockToken := &oauth2.Token{AccessToken: "test-token"}
8283
mockProvider := NewMockProvider(ctrl)
8384
mockProvider.EXPECT().Name().Return("test").AnyTimes()
8485
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
8586
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
87+
mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil)
88+
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
89+
mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("test-user", nil)
90+
mockProvider.EXPECT().Authorization("test-user").Return(true, nil).AnyTimes()
8691

8792
// Create AuthRouter
8893
authRouter, err := NewAuthRouter(nil, mockProvider)
@@ -103,8 +108,6 @@ func TestAuthenticationFlow(t *testing.T) {
103108
require.Equal(t, http.StatusFound, resp.StatusCode)
104109

105110
// Step 2: Start authentication
106-
mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil)
107-
108111
resp, err = client.Get(server.URL + "/.auth/test")
109112
require.NoError(t, err)
110113
resp.Body.Close()
@@ -115,23 +118,15 @@ func TestAuthenticationFlow(t *testing.T) {
115118
require.Equal(t, "https://example.com/oauth", location)
116119

117120
// Step 3: Handle callback
118-
mockToken := &oauth2.Token{AccessToken: "test-token"}
119-
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
120-
mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("test-user", nil)
121-
122121
resp, err = client.Get(server.URL + "/.auth/test/callback")
123122
require.NoError(t, err)
124123
resp.Body.Close()
125124

126125
require.Equal(t, http.StatusFound, resp.StatusCode)
127-
128-
// Verify redirect to root
129126
location = resp.Header.Get("Location")
130127
require.Equal(t, "/", location)
131128

132129
// Step 4: Access after authentication
133-
mockProvider.EXPECT().Authorization("test-user").Return(true, nil)
134-
135130
resp, err = client.Get(server.URL + "/")
136131
if err != nil {
137132
t.Fatalf("Request failed: %v", err)
@@ -146,10 +141,15 @@ func TestAuthenticationFlow(t *testing.T) {
146141
defer ctrl.Finish()
147142

148143
// Create mock provider
144+
mockToken := &oauth2.Token{AccessToken: "test-token"}
149145
mockProvider := NewMockProvider(ctrl)
150146
mockProvider.EXPECT().Name().Return("test").AnyTimes()
151147
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
152148
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
149+
mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil)
150+
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
151+
mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("unauthorized-user", nil)
152+
mockProvider.EXPECT().Authorization("unauthorized-user").Return(false, nil).AnyTimes()
153153

154154
// Create AuthRouter
155155
authRouter, err := NewAuthRouter(nil, mockProvider)
@@ -167,30 +167,26 @@ func TestAuthenticationFlow(t *testing.T) {
167167
resp.Body.Close()
168168

169169
// Step 2: Start authentication
170-
mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil)
171-
172170
resp, err = client.Get(server.URL + "/.auth/test")
173171
require.NoError(t, err)
174172
resp.Body.Close()
175173

176174
// Step 3: Complete authentication
177-
mockToken := &oauth2.Token{AccessToken: "test-token"}
178-
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
179-
mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("unauthorized-user", nil)
180-
181175
resp, err = client.Get(server.URL + "/.auth/test/callback")
182176
require.NoError(t, err)
183177
resp.Body.Close()
184178

185-
// Step 4: Test access when authorization fails
186-
mockProvider.EXPECT().Authorization("unauthorized-user").Return(false, nil)
179+
require.Equal(t, http.StatusForbidden, resp.StatusCode)
187180

181+
// Step 4: Test access when authorization fails
188182
resp, err = client.Get(server.URL + "/")
189183
if err != nil {
190184
t.Fatalf("Request failed: %v", err)
191185
}
192186
defer resp.Body.Close()
193187

194-
require.Equal(t, http.StatusForbidden, resp.StatusCode)
188+
require.Equal(t, http.StatusFound, resp.StatusCode)
189+
location := resp.Header.Get("Location")
190+
require.Equal(t, "/.auth/login", location)
195191
})
196192
}

0 commit comments

Comments
 (0)