Skip to content

Commit 3882d5b

Browse files
committed
fix: improve authentication flow and session handling
- Add authorization check after user authentication - Fix redirect URL handling in session management - Refactor template rendering into dedicated methods - Improve logout behavior to redirect to login page - Clean up session management by deleting specific keys instead of clearing all
1 parent b5736d4 commit 3882d5b

1 file changed

Lines changed: 72 additions & 52 deletions

File tree

pkg/auth/auth.go

Lines changed: 72 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ var templateFS embed.FS
1818
type AuthRouter struct {
1919
passwordHash []string
2020
providers []Provider
21-
template *template.Template
21+
loginTemplate *template.Template
2222
unauthorizedTemplate *template.Template
2323
}
2424

@@ -36,7 +36,7 @@ func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, e
3636
return &AuthRouter{
3737
passwordHash: passwordHash,
3838
providers: providers,
39-
template: tmpl,
39+
loginTemplate: tmpl,
4040
unauthorizedTemplate: unauthorizedTmpl,
4141
}, nil
4242
}
@@ -82,11 +82,29 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
8282
c.Error(err)
8383
return
8484
}
85+
ok, err := provider.Authorization(userID)
86+
if err != nil {
87+
c.Error(err)
88+
return
89+
}
90+
if !ok {
91+
a.renderUnauthorized(c, userID, provider.Name())
92+
return
93+
}
8594
session.Set(SessionKeyProvider, provider.Name())
8695
session.Set(SessionKeyUserID, userID)
87-
session.Save()
8896
redirectURL := session.Get(SessionKeyRedirectURL)
89-
c.Redirect(http.StatusFound, redirectURL.(string))
97+
if redirectURL != nil {
98+
session.Delete(SessionKeyRedirectURL)
99+
}
100+
session.Save()
101+
102+
if redirectURL == nil {
103+
c.Redirect(http.StatusFound, "/")
104+
} else {
105+
c.Redirect(http.StatusFound, redirectURL.(string))
106+
}
107+
90108
})
91109

92110
router.GET(provider.AuthURL(), func(c *gin.Context) {
@@ -118,29 +136,12 @@ func (a *AuthRouter) getProvider(name string) Provider {
118136
return nil
119137
}
120138

121-
type templateData struct {
122-
Providers []Provider
123-
HasPassword bool
124-
PasswordError string
125-
}
126-
127139
func (a *AuthRouter) handleLogin(c *gin.Context) {
128140
if c.Request.Method == "POST" {
129141
a.handleLoginPost(c)
130142
return
131143
}
132-
133-
data := templateData{
134-
Providers: a.providers,
135-
HasPassword: len(a.passwordHash) > 0,
136-
PasswordError: "",
137-
}
138-
139-
c.Header("Content-Type", "text/html; charset=utf-8")
140-
if err := a.template.Execute(c.Writer, data); err != nil {
141-
c.AbortWithError(http.StatusInternalServerError, err)
142-
return
143-
}
144+
a.renderLogin(c, "")
144145
}
145146

146147
func (a *AuthRouter) handleLoginPost(c *gin.Context) {
@@ -165,39 +166,32 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
165166
}
166167

167168
if errorMessage != "" {
168-
data := templateData{
169-
Providers: a.providers,
170-
HasPassword: len(a.passwordHash) > 0,
171-
PasswordError: errorMessage,
172-
}
173-
174-
c.Header("Content-Type", "text/html; charset=utf-8")
175-
c.Status(http.StatusBadRequest)
176-
if err := a.template.Execute(c.Writer, data); err != nil {
177-
c.AbortWithError(http.StatusInternalServerError, err)
178-
return
179-
}
169+
a.renderLogin(c, errorMessage)
180170
return
181171
}
182172

183173
session := sessions.Default(c)
184174
session.Set(SessionKeyProvider, PasswordProvider)
185175
session.Set(SessionKeyUserID, PasswordUserID)
176+
redirectURL := session.Get(SessionKeyRedirectURL)
177+
if redirectURL != nil {
178+
session.Delete(SessionKeyRedirectURL)
179+
}
186180
session.Save()
187181

188-
redirectURL := session.Get(SessionKeyRedirectURL)
189182
if redirectURL == nil {
190183
c.Redirect(http.StatusFound, "/")
191-
return
184+
} else {
185+
c.Redirect(http.StatusFound, redirectURL.(string))
192186
}
193-
c.Redirect(http.StatusFound, redirectURL.(string))
194187
}
195188

196189
func (a *AuthRouter) handleLogout(c *gin.Context) {
197190
session := sessions.Default(c)
198-
session.Clear()
191+
session.Delete(SessionKeyProvider)
192+
session.Delete(SessionKeyUserID)
199193
session.Save()
200-
c.String(http.StatusOK, "Logged out")
194+
c.Redirect(http.StatusFound, LoginEndpoint)
201195
}
202196

203197
func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
@@ -229,22 +223,48 @@ func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
229223
return
230224
}
231225
if !ok {
232-
data := struct {
233-
UserID string
234-
Provider string
235-
}{
236-
UserID: userID.(string),
237-
Provider: providerName.(string),
238-
}
239-
c.Header("Content-Type", "text/html; charset=utf-8")
240-
c.Status(http.StatusForbidden)
241-
if err := a.unauthorizedTemplate.Execute(c.Writer, data); err != nil {
242-
c.AbortWithError(http.StatusInternalServerError, err)
243-
return
244-
}
226+
a.renderUnauthorized(c, userID.(string), providerName.(string))
245227
c.Abort()
246228
return
247229
}
248230
c.Next()
249231
}
250232
}
233+
234+
type loginTemplateData struct {
235+
Providers []Provider
236+
HasPassword bool
237+
PasswordError string
238+
}
239+
240+
type unauthorizedTemplateData struct {
241+
UserID string
242+
Provider string
243+
}
244+
245+
func (a *AuthRouter) renderLogin(c *gin.Context, passwordError string) {
246+
data := loginTemplateData{
247+
Providers: a.providers,
248+
HasPassword: len(a.passwordHash) > 0,
249+
PasswordError: passwordError,
250+
}
251+
c.Header("Content-Type", "text/html; charset=utf-8")
252+
c.Status(http.StatusBadRequest)
253+
if err := a.loginTemplate.Execute(c.Writer, data); err != nil {
254+
c.AbortWithError(http.StatusInternalServerError, err)
255+
return
256+
}
257+
}
258+
259+
func (a *AuthRouter) renderUnauthorized(c *gin.Context, userID, providerName string) {
260+
data := unauthorizedTemplateData{
261+
UserID: userID,
262+
Provider: providerName,
263+
}
264+
c.Header("Content-Type", "text/html; charset=utf-8")
265+
c.Status(http.StatusForbidden)
266+
if err := a.unauthorizedTemplate.Execute(c.Writer, data); err != nil {
267+
c.AbortWithError(http.StatusInternalServerError, err)
268+
return
269+
}
270+
}

0 commit comments

Comments
 (0)