Skip to content

Commit 0bcddea

Browse files
authored
refactor: restructure package organization and add comprehensive tests (#35)
- Reorganize files to use descriptive names (main.go → package.go) - Move Provider interface to dedicated interface.go files - Add mock generation with go:generate directives - Implement comprehensive unit tests for all packages: - auth: OAuth flow, authentication middleware, authorization checks - backend: MCP proxy setup, command execution, test server - idp: OAuth server metadata, JWKS, client registration and flow - proxy: JWT validation, token extraction, proxy handling - Add testify for assertions and testserver for backend testing - Update dependencies: add testify and memstore for session testing - Delete obsolete mock files and consolidate utility functions
1 parent 1268a0e commit 0bcddea

18 files changed

Lines changed: 844 additions & 493 deletions

File tree

go.mod

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ require (
1111
github.com/mark3labs/mcp-go v0.37.0
1212
github.com/ory/fosite v0.49.0
1313
github.com/spf13/cobra v1.8.1
14+
github.com/stretchr/testify v1.10.0
1415
go.etcd.io/bbolt v1.4.2
1516
go.uber.org/mock v0.5.2
1617
go.uber.org/zap v1.27.0
@@ -77,14 +78,14 @@ require (
7778
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
7879
github.com/pkg/errors v0.9.1 // indirect
7980
github.com/pmezard/go-difflib v1.0.0 // indirect
81+
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect
8082
github.com/seatgeek/logrus-gelf-formatter v0.0.0-20210414080842-5b05eb8ff761 // indirect
8183
github.com/sirupsen/logrus v1.9.3 // indirect
8284
github.com/spf13/afero v1.9.5 // indirect
8385
github.com/spf13/cast v1.7.1 // indirect
8486
github.com/spf13/jwalterweatherman v1.1.0 // indirect
8587
github.com/spf13/pflag v1.0.6 // indirect
8688
github.com/spf13/viper v1.16.0 // indirect
87-
github.com/stretchr/testify v1.10.0 // indirect
8889
github.com/subosito/gotenv v1.4.2 // indirect
8990
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
9091
github.com/ugorji/go/codec v1.2.12 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,8 @@ github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qR
406406
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
407407
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
408408
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
409+
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b h1:aUNXCGgukb4gtY99imuIeoh8Vr0GSwAlYxPAhqZrpFc=
410+
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
409411
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
410412
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
411413
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=

pkg/auth/main.go renamed to pkg/auth/auth.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package auth
22

33
import (
4-
"context"
54
"embed"
65
"errors"
76
"html/template"
@@ -11,22 +10,11 @@ import (
1110
"github.com/gin-gonic/gin"
1211
"github.com/sigbit/mcp-auth-proxy/pkg/utils"
1312
"golang.org/x/crypto/bcrypt"
14-
"golang.org/x/oauth2"
1513
)
1614

1715
//go:embed templates/*
1816
var templateFS embed.FS
1917

20-
type Provider interface {
21-
Name() string
22-
RedirectURL() string
23-
AuthURL() string
24-
AuthCodeURL(c *gin.Context, state string) (string, error)
25-
Exchange(c *gin.Context, state string) (*oauth2.Token, error)
26-
GetUserID(ctx context.Context, token *oauth2.Token) (string, error)
27-
Authorization(userid string) (bool, error)
28-
}
29-
3018
type AuthRouter struct {
3119
passwordHash []string
3220
providers map[string]Provider

pkg/auth/auth_test.go

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package auth
2+
3+
import (
4+
"net/http"
5+
"net/http/cookiejar"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/gin-contrib/sessions"
10+
"github.com/gin-contrib/sessions/memstore"
11+
"github.com/gin-gonic/gin"
12+
"github.com/stretchr/testify/require"
13+
"go.uber.org/mock/gomock"
14+
"golang.org/x/oauth2"
15+
)
16+
17+
func setupTestRouter(authRouter *AuthRouter) *gin.Engine {
18+
gin.SetMode(gin.TestMode)
19+
router := gin.New()
20+
21+
// Setup session middleware
22+
store := memstore.NewStore([]byte("test-secret"))
23+
router.Use(sessions.Sessions("session", store))
24+
25+
// Setup dummy protected route
26+
router.GET("/", authRouter.RequireAuth(), func(c *gin.Context) {
27+
c.String(http.StatusOK, "authenticated")
28+
})
29+
30+
// Setup authentication routes
31+
authRouter.SetupRoutes(router)
32+
33+
return router
34+
}
35+
36+
func setupClient() *http.Client {
37+
jar, _ := cookiejar.New(nil)
38+
return &http.Client{
39+
Jar: jar,
40+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
41+
return http.ErrUseLastResponse
42+
},
43+
}
44+
}
45+
46+
func TestAuthenticationFlow(t *testing.T) {
47+
t.Run("Unauthenticated access should redirect to login", func(t *testing.T) {
48+
ctrl := gomock.NewController(t)
49+
defer ctrl.Finish()
50+
51+
// Create mock provider
52+
mockProvider := NewMockProvider(ctrl)
53+
mockProvider.EXPECT().Name().Return("test").AnyTimes()
54+
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
55+
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
56+
57+
// Create AuthRouter
58+
authRouter, err := NewAuthRouter(nil, mockProvider)
59+
require.NoError(t, err)
60+
61+
router := setupTestRouter(authRouter)
62+
server := httptest.NewServer(router)
63+
defer server.Close()
64+
65+
client := setupClient()
66+
67+
resp, err := client.Get(server.URL + "/")
68+
require.NoError(t, err)
69+
defer resp.Body.Close()
70+
71+
require.Equal(t, http.StatusFound, resp.StatusCode)
72+
73+
location := resp.Header.Get("Location")
74+
require.Equal(t, LoginEndpoint, location)
75+
})
76+
77+
t.Run("OAuth authentication flow", func(t *testing.T) {
78+
ctrl := gomock.NewController(t)
79+
defer ctrl.Finish()
80+
81+
// Create mock provider
82+
mockProvider := NewMockProvider(ctrl)
83+
mockProvider.EXPECT().Name().Return("test").AnyTimes()
84+
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
85+
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
86+
87+
// Create AuthRouter
88+
authRouter, err := NewAuthRouter(nil, mockProvider)
89+
require.NoError(t, err)
90+
91+
router := setupTestRouter(authRouter)
92+
server := httptest.NewServer(router)
93+
defer server.Close()
94+
95+
client := setupClient()
96+
97+
// Step 1: Access unauthenticated route first to set redirectURL in session
98+
resp, err := client.Get(server.URL + "/")
99+
require.NoError(t, err)
100+
resp.Body.Close()
101+
102+
// Verify redirect to login page
103+
require.Equal(t, http.StatusFound, resp.StatusCode)
104+
105+
// Step 2: Start authentication
106+
mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil)
107+
108+
resp, err = client.Get(server.URL + "/.auth/test")
109+
require.NoError(t, err)
110+
resp.Body.Close()
111+
112+
require.Equal(t, http.StatusFound, resp.StatusCode)
113+
114+
location := resp.Header.Get("Location")
115+
require.Equal(t, "https://example.com/oauth", location)
116+
117+
// Step 3: Handle callback
118+
mockToken := &oauth2.Token{AccessToken: "test-token"}
119+
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
120+
mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("test-user", nil)
121+
122+
resp, err = client.Get(server.URL + "/.auth/test/callback")
123+
require.NoError(t, err)
124+
resp.Body.Close()
125+
126+
require.Equal(t, http.StatusFound, resp.StatusCode)
127+
128+
// Verify redirect to root
129+
location = resp.Header.Get("Location")
130+
require.Equal(t, "/", location)
131+
132+
// Step 4: Access after authentication
133+
mockProvider.EXPECT().Authorization("test-user").Return(true, nil)
134+
135+
resp, err = client.Get(server.URL + "/")
136+
if err != nil {
137+
t.Fatalf("Request failed: %v", err)
138+
}
139+
defer resp.Body.Close()
140+
141+
require.Equal(t, http.StatusOK, resp.StatusCode)
142+
})
143+
144+
t.Run("Unauthorized user should be blocked", func(t *testing.T) {
145+
ctrl := gomock.NewController(t)
146+
defer ctrl.Finish()
147+
148+
// Create mock provider
149+
mockProvider := NewMockProvider(ctrl)
150+
mockProvider.EXPECT().Name().Return("test").AnyTimes()
151+
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
152+
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
153+
154+
// Create AuthRouter
155+
authRouter, err := NewAuthRouter(nil, mockProvider)
156+
require.NoError(t, err)
157+
158+
router := setupTestRouter(authRouter)
159+
server := httptest.NewServer(router)
160+
defer server.Close()
161+
162+
client := setupClient()
163+
164+
// Step 1: Access unauthenticated route first
165+
resp, err := client.Get(server.URL + "/")
166+
require.NoError(t, err)
167+
resp.Body.Close()
168+
169+
// Step 2: Start authentication
170+
mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil)
171+
172+
resp, err = client.Get(server.URL + "/.auth/test")
173+
require.NoError(t, err)
174+
resp.Body.Close()
175+
176+
// Step 3: Complete authentication
177+
mockToken := &oauth2.Token{AccessToken: "test-token"}
178+
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
179+
mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("unauthorized-user", nil)
180+
181+
resp, err = client.Get(server.URL + "/.auth/test/callback")
182+
require.NoError(t, err)
183+
resp.Body.Close()
184+
185+
// Step 4: Test access when authorization fails
186+
mockProvider.EXPECT().Authorization("unauthorized-user").Return(false, nil)
187+
188+
resp, err = client.Get(server.URL + "/")
189+
if err != nil {
190+
t.Fatalf("Request failed: %v", err)
191+
}
192+
defer resp.Body.Close()
193+
194+
require.Equal(t, http.StatusForbidden, resp.StatusCode)
195+
})
196+
}

pkg/auth/interface.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//go:generate mockgen -source=interface.go -destination=mock.go -package=auth
2+
package auth
3+
4+
import (
5+
"context"
6+
7+
"github.com/gin-gonic/gin"
8+
"golang.org/x/oauth2"
9+
)
10+
11+
type Provider interface {
12+
Name() string
13+
RedirectURL() string
14+
AuthURL() string
15+
AuthCodeURL(c *gin.Context, state string) (string, error)
16+
Exchange(c *gin.Context, state string) (*oauth2.Token, error)
17+
GetUserID(ctx context.Context, token *oauth2.Token) (string, error)
18+
Authorization(userid string) (bool, error)
19+
}

0 commit comments

Comments
 (0)