@@ -18,7 +18,7 @@ var templateFS embed.FS
1818type AuthRouter struct {
1919 passwordHash []string
2020 providers []Provider
21- template * template.Template
21+ loginTemplate * template.Template
2222 unauthorizedTemplate * template.Template
2323}
2424
@@ -36,7 +36,7 @@ func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, e
3636 return & AuthRouter {
3737 passwordHash : passwordHash ,
3838 providers : providers ,
39- template : tmpl ,
39+ loginTemplate : tmpl ,
4040 unauthorizedTemplate : unauthorizedTmpl ,
4141 }, nil
4242}
@@ -82,11 +82,29 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
8282 c .Error (err )
8383 return
8484 }
85+ ok , err := provider .Authorization (userID )
86+ if err != nil {
87+ c .Error (err )
88+ return
89+ }
90+ if ! ok {
91+ a .renderUnauthorized (c , userID , provider .Name ())
92+ return
93+ }
8594 session .Set (SessionKeyProvider , provider .Name ())
8695 session .Set (SessionKeyUserID , userID )
87- session .Save ()
8896 redirectURL := session .Get (SessionKeyRedirectURL )
89- c .Redirect (http .StatusFound , redirectURL .(string ))
97+ if redirectURL != nil {
98+ session .Delete (SessionKeyRedirectURL )
99+ }
100+ session .Save ()
101+
102+ if redirectURL == nil {
103+ c .Redirect (http .StatusFound , "/" )
104+ } else {
105+ c .Redirect (http .StatusFound , redirectURL .(string ))
106+ }
107+
90108 })
91109
92110 router .GET (provider .AuthURL (), func (c * gin.Context ) {
@@ -118,29 +136,12 @@ func (a *AuthRouter) getProvider(name string) Provider {
118136 return nil
119137}
120138
121- type templateData struct {
122- Providers []Provider
123- HasPassword bool
124- PasswordError string
125- }
126-
127139func (a * AuthRouter ) handleLogin (c * gin.Context ) {
128140 if c .Request .Method == "POST" {
129141 a .handleLoginPost (c )
130142 return
131143 }
132-
133- data := templateData {
134- Providers : a .providers ,
135- HasPassword : len (a .passwordHash ) > 0 ,
136- PasswordError : "" ,
137- }
138-
139- c .Header ("Content-Type" , "text/html; charset=utf-8" )
140- if err := a .template .Execute (c .Writer , data ); err != nil {
141- c .AbortWithError (http .StatusInternalServerError , err )
142- return
143- }
144+ a .renderLogin (c , "" )
144145}
145146
146147func (a * AuthRouter ) handleLoginPost (c * gin.Context ) {
@@ -165,39 +166,32 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
165166 }
166167
167168 if errorMessage != "" {
168- data := templateData {
169- Providers : a .providers ,
170- HasPassword : len (a .passwordHash ) > 0 ,
171- PasswordError : errorMessage ,
172- }
173-
174- c .Header ("Content-Type" , "text/html; charset=utf-8" )
175- c .Status (http .StatusBadRequest )
176- if err := a .template .Execute (c .Writer , data ); err != nil {
177- c .AbortWithError (http .StatusInternalServerError , err )
178- return
179- }
169+ a .renderLogin (c , errorMessage )
180170 return
181171 }
182172
183173 session := sessions .Default (c )
184174 session .Set (SessionKeyProvider , PasswordProvider )
185175 session .Set (SessionKeyUserID , PasswordUserID )
176+ redirectURL := session .Get (SessionKeyRedirectURL )
177+ if redirectURL != nil {
178+ session .Delete (SessionKeyRedirectURL )
179+ }
186180 session .Save ()
187181
188- redirectURL := session .Get (SessionKeyRedirectURL )
189182 if redirectURL == nil {
190183 c .Redirect (http .StatusFound , "/" )
191- return
184+ } else {
185+ c .Redirect (http .StatusFound , redirectURL .(string ))
192186 }
193- c .Redirect (http .StatusFound , redirectURL .(string ))
194187}
195188
196189func (a * AuthRouter ) handleLogout (c * gin.Context ) {
197190 session := sessions .Default (c )
198- session .Clear ()
191+ session .Delete (SessionKeyProvider )
192+ session .Delete (SessionKeyUserID )
199193 session .Save ()
200- c .String (http .StatusOK , "Logged out" )
194+ c .Redirect (http .StatusFound , LoginEndpoint )
201195}
202196
203197func (a * AuthRouter ) RequireAuth () gin.HandlerFunc {
@@ -229,22 +223,48 @@ func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
229223 return
230224 }
231225 if ! ok {
232- data := struct {
233- UserID string
234- Provider string
235- }{
236- UserID : userID .(string ),
237- Provider : providerName .(string ),
238- }
239- c .Header ("Content-Type" , "text/html; charset=utf-8" )
240- c .Status (http .StatusForbidden )
241- if err := a .unauthorizedTemplate .Execute (c .Writer , data ); err != nil {
242- c .AbortWithError (http .StatusInternalServerError , err )
243- return
244- }
226+ a .renderUnauthorized (c , userID .(string ), providerName .(string ))
245227 c .Abort ()
246228 return
247229 }
248230 c .Next ()
249231 }
250232}
233+
234+ type loginTemplateData struct {
235+ Providers []Provider
236+ HasPassword bool
237+ PasswordError string
238+ }
239+
240+ type unauthorizedTemplateData struct {
241+ UserID string
242+ Provider string
243+ }
244+
245+ func (a * AuthRouter ) renderLogin (c * gin.Context , passwordError string ) {
246+ data := loginTemplateData {
247+ Providers : a .providers ,
248+ HasPassword : len (a .passwordHash ) > 0 ,
249+ PasswordError : passwordError ,
250+ }
251+ c .Header ("Content-Type" , "text/html; charset=utf-8" )
252+ c .Status (http .StatusBadRequest )
253+ if err := a .loginTemplate .Execute (c .Writer , data ); err != nil {
254+ c .AbortWithError (http .StatusInternalServerError , err )
255+ return
256+ }
257+ }
258+
259+ func (a * AuthRouter ) renderUnauthorized (c * gin.Context , userID , providerName string ) {
260+ data := unauthorizedTemplateData {
261+ UserID : userID ,
262+ Provider : providerName ,
263+ }
264+ c .Header ("Content-Type" , "text/html; charset=utf-8" )
265+ c .Status (http .StatusForbidden )
266+ if err := a .unauthorizedTemplate .Execute (c .Writer , data ); err != nil {
267+ c .AbortWithError (http .StatusInternalServerError , err )
268+ return
269+ }
270+ }
0 commit comments