Skip to content

Commit e847ed5

Browse files
authored
feat: forward authenticated user identity to upstream via headers (#135)
Add --header-mapping flag to inject OIDC/Google/GitHub userinfo attributes into upstream request headers. Userinfo is embedded as JWT custom claims and extracted in the proxy using JSON pointers. Also fixes JWT subject claim from hardcoded "user" to the actual authenticated user's identity. Closes #130
1 parent 9e0af0b commit e847ed5

17 files changed

Lines changed: 324 additions & 76 deletions

main.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,30 @@ func parseAttributeMap(s string) map[string][]string {
9191
return result
9292
}
9393

94+
func parseHeaderMapping(s string) map[string]string {
95+
result := make(map[string]string)
96+
if s == "" {
97+
return result
98+
}
99+
parts := splitWithEscapes(s, ",")
100+
for _, part := range parts {
101+
part = strings.TrimSpace(part)
102+
if part == "" {
103+
continue
104+
}
105+
colonIdx := strings.LastIndex(part, ":")
106+
if colonIdx == -1 {
107+
continue
108+
}
109+
pointer := strings.TrimSpace(part[:colonIdx])
110+
header := strings.TrimSpace(part[colonIdx+1:])
111+
if pointer != "" && header != "" {
112+
result[pointer] = header
113+
}
114+
}
115+
return result
116+
}
117+
94118
type proxyRunnerFunc func(
95119
listen string,
96120
tlsListen string,
@@ -130,6 +154,7 @@ type proxyRunnerFunc func(
130154
proxyBearerToken string,
131155
proxyTarget []string,
132156
httpStreamingOnly bool,
157+
headerMapping map[string]string,
133158
) error
134159

135160
func main() {
@@ -174,6 +199,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
174199
var passwordHash string
175200
var proxyBearerToken string
176201
var proxyHeaders string
202+
var headerMapping string
177203
var httpStreamingOnly bool
178204
var trustedProxies string
179205

@@ -255,6 +281,8 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
255281
}
256282
}
257283

284+
headerMappingMap := parseHeaderMapping(headerMapping)
285+
258286
if err := run(
259287
listen,
260288
tlsListen,
@@ -294,6 +322,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
294322
proxyBearerToken,
295323
args,
296324
httpStreamingOnly,
325+
headerMappingMap,
297326
); err != nil {
298327
panic(err)
299328
}
@@ -347,6 +376,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
347376
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", getEnvWithDefault("TRUSTED_PROXIES", ""), "Comma-separated list of trusted proxies (IP addresses or CIDR ranges)")
348377
rootCmd.Flags().StringVar(&proxyHeaders, "proxy-headers", getEnvWithDefault("PROXY_HEADERS", ""), "Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2)")
349378
rootCmd.Flags().BoolVar(&httpStreamingOnly, "http-streaming-only", getEnvBoolWithDefault("HTTP_STREAMING_ONLY", false), "Reject SSE (GET) requests and keep the backend in HTTP streaming-only mode")
379+
rootCmd.Flags().StringVar(&headerMapping, "header-mapping", getEnvWithDefault("HEADER_MAPPING", ""), "Comma-separated mapping of userinfo JSON pointer paths to header names (e.g., /email:X-Forwarded-Email,/preferred_username:X-Forwarded-User)")
350380

351381
return rootCmd
352382
}

main_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,64 @@ func TestParseAttributeMap(t *testing.T) {
168168
}
169169
}
170170

171+
func TestParseHeaderMapping(t *testing.T) {
172+
testCases := []struct {
173+
name string
174+
input string
175+
expected map[string]string
176+
}{
177+
{
178+
name: "empty string",
179+
input: "",
180+
expected: map[string]string{},
181+
},
182+
{
183+
name: "single mapping",
184+
input: "/email:X-Forwarded-Email",
185+
expected: map[string]string{
186+
"/email": "X-Forwarded-Email",
187+
},
188+
},
189+
{
190+
name: "multiple mappings",
191+
input: "/email:X-Forwarded-Email,/preferred_username:X-Forwarded-User",
192+
expected: map[string]string{
193+
"/email": "X-Forwarded-Email",
194+
"/preferred_username": "X-Forwarded-User",
195+
},
196+
},
197+
{
198+
name: "nested JSON pointer",
199+
input: "/org/team:X-Forwarded-Team",
200+
expected: map[string]string{
201+
"/org/team": "X-Forwarded-Team",
202+
},
203+
},
204+
{
205+
name: "whitespace trimming",
206+
input: " /email : X-Forwarded-Email , /sub : X-Forwarded-Sub ",
207+
expected: map[string]string{
208+
"/email": "X-Forwarded-Email",
209+
"/sub": "X-Forwarded-Sub",
210+
},
211+
},
212+
{
213+
name: "no colon - skipped",
214+
input: "invalid",
215+
expected: map[string]string{},
216+
},
217+
}
218+
219+
for _, tc := range testCases {
220+
t.Run(tc.name, func(t *testing.T) {
221+
result := parseHeaderMapping(tc.input)
222+
if !reflect.DeepEqual(result, tc.expected) {
223+
t.Errorf("Expected %v, got %v", tc.expected, result)
224+
}
225+
})
226+
}
227+
}
228+
171229
func TestGetEnvWithDefault(t *testing.T) {
172230
testCases := []struct {
173231
name string
@@ -344,6 +402,7 @@ func TestNewRootCommand_HTTPStreamingOnlyFlag(t *testing.T) {
344402
proxyBearerToken string,
345403
proxyTarget []string,
346404
httpStreamingOnly bool,
405+
headerMapping map[string]string,
347406
) error {
348407
streamingOnly = httpStreamingOnly
349408
receivedTargets = proxyTarget
@@ -407,6 +466,7 @@ func TestNewRootCommand_HTTPStreamingOnlyFromEnv(t *testing.T) {
407466
proxyBearerToken string,
408467
proxyTarget []string,
409468
httpStreamingOnly bool,
469+
headerMapping map[string]string,
410470
) error {
411471
streamingOnly = httpStreamingOnly
412472
return nil

pkg/auth/auth.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package auth
22

33
import (
44
"embed"
5+
"encoding/json"
56
"errors"
67
"html/template"
78
"net/http"
@@ -68,6 +69,8 @@ const (
6869
SessionKeyAuthorized = "authorized"
6970
SessionKeyRedirectURL = "redirect_url"
7071
SessionKeyOAuthState = "oauth_state"
72+
SessionKeyUserID = "user_id"
73+
SessionKeyUserInfo = "user_info"
7174
)
7275

7376
func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
@@ -87,7 +90,7 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
8790
a.renderError(c, err)
8891
return
8992
}
90-
ok, user, err := provider.Authorization(c, token)
93+
ok, user, userInfo, err := provider.Authorization(c, token)
9194
if err != nil {
9295
a.renderError(c, err)
9396
return
@@ -97,6 +100,12 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
97100
return
98101
}
99102
session.Set(SessionKeyAuthorized, true)
103+
session.Set(SessionKeyUserID, user)
104+
if userInfo != nil {
105+
if userInfoJSON, err := json.Marshal(userInfo); err == nil {
106+
session.Set(SessionKeyUserInfo, string(userInfoJSON))
107+
}
108+
}
100109
redirectURL := session.Get(SessionKeyRedirectURL)
101110
if redirectURL != nil {
102111
session.Delete(SessionKeyRedirectURL)
@@ -177,6 +186,7 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
177186

178187
session := sessions.Default(c)
179188
session.Set(SessionKeyAuthorized, true)
189+
session.Set(SessionKeyUserID, PasswordUserID)
180190
redirectURL := session.Get(SessionKeyRedirectURL)
181191
if redirectURL != nil {
182192
session.Delete(SessionKeyRedirectURL)

pkg/auth/auth_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func TestAuthenticationFlow(t *testing.T) {
8585
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
8686
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
8787
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
88-
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", nil)
88+
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", map[string]any{"email": "[email protected]"}, nil)
8989

9090
// Create AuthRouter
9191
authRouter, err := NewAuthRouter(nil, false, mockProvider)
@@ -146,7 +146,7 @@ func TestAuthenticationFlow(t *testing.T) {
146146
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
147147
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
148148
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
149-
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", nil)
149+
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", map[string]any{"email": "[email protected]"}, nil)
150150

151151
// Create AuthRouter
152152
authRouter, err := NewAuthRouter(nil, false, mockProvider)

pkg/auth/github.go

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -85,30 +85,30 @@ func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token,
8585
return token, nil
8686
}
8787

88-
func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) {
88+
func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) {
8989
client := p.oauth2.Client(ctx, token)
9090
resp1, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user")))
9191
if err != nil {
92-
return false, "", err
92+
return false, "", nil, err
9393
}
9494
if resp1.StatusCode < 200 || resp1.StatusCode >= 300 {
95-
return false, "", errors.New("failed to get user info from GitHub API: " + resp1.Status)
95+
return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp1.Status)
9696
}
9797
defer resp1.Body.Close()
9898

99-
var userInfo struct {
100-
Login string `json:"login"`
101-
}
102-
if err := json.NewDecoder(resp1.Body).Decode(&userInfo); err != nil {
103-
return false, "", err
99+
var userInfoMap map[string]any
100+
if err := json.NewDecoder(resp1.Body).Decode(&userInfoMap); err != nil {
101+
return false, "", nil, err
104102
}
105103

104+
login, _ := userInfoMap["login"].(string)
105+
106106
if len(p.allowedUsers) == 0 && len(p.allowedOrgs) == 0 {
107-
return true, userInfo.Login, nil
107+
return true, login, userInfoMap, nil
108108
}
109109

110-
if slices.Contains(p.allowedUsers, userInfo.Login) {
111-
return true, userInfo.Login, nil
110+
if slices.Contains(p.allowedUsers, login) {
111+
return true, login, userInfoMap, nil
112112
}
113113

114114
allowedOrgTeams := []string{}
@@ -124,31 +124,31 @@ func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token)
124124
if len(allowedOrgs) > 0 {
125125
resp2, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/orgs")))
126126
if err != nil {
127-
return false, "", err
127+
return false, "", nil, err
128128
}
129129
if resp2.StatusCode < 200 || resp2.StatusCode >= 300 {
130-
return false, "", errors.New("failed to get user info from GitHub API: " + resp2.Status)
130+
return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp2.Status)
131131
}
132132
defer resp2.Body.Close()
133133
var orgInfo []struct {
134134
Login string `json:"login"`
135135
}
136136
if err := json.NewDecoder(resp2.Body).Decode(&orgInfo); err != nil {
137-
return false, "", err
137+
return false, "", nil, err
138138
}
139139
for _, o := range orgInfo {
140140
if slices.Contains(allowedOrgs, o.Login) {
141-
return true, userInfo.Login, nil
141+
return true, login, userInfoMap, nil
142142
}
143143
}
144144
}
145145
if len(allowedOrgTeams) > 0 {
146146
resp3, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/teams")))
147147
if err != nil {
148-
return false, "", err
148+
return false, "", nil, err
149149
}
150150
if resp3.StatusCode < 200 || resp3.StatusCode >= 300 {
151-
return false, "", errors.New("failed to get user info from GitHub API: " + resp3.Status)
151+
return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp3.Status)
152152
}
153153
defer resp3.Body.Close()
154154
var teamInfo []struct {
@@ -158,14 +158,14 @@ func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token)
158158
Slug string `json:"slug"`
159159
}
160160
if err := json.NewDecoder(resp3.Body).Decode(&teamInfo); err != nil {
161-
return false, "", err
161+
return false, "", nil, err
162162
}
163163
for _, team := range teamInfo {
164164
if slices.Contains(allowedOrgTeams, team.Organization.Login+":"+team.Slug) {
165-
return true, userInfo.Login, nil
165+
return true, login, userInfoMap, nil
166166
}
167167
}
168168
}
169169

170-
return false, userInfo.Login, nil
170+
return false, login, userInfoMap, nil
171171
}

pkg/auth/github_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func TestGitHubProviderAuthorization(t *testing.T) {
168168
})
169169

170170
// Call the Authorization method
171-
ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"})
171+
ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"})
172172
require.NoError(t, err)
173173
require.Equal(t, expect, ok)
174174
})

pkg/auth/google.go

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -83,38 +83,36 @@ func (p *googleProvider) Exchange(c *gin.Context, state string) (*oauth2.Token,
8383
return token, nil
8484
}
8585

86-
func (p *googleProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) {
86+
func (p *googleProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) {
8787
client := p.oauth2.Client(ctx, token)
8888
resp, err := client.Get(p.userinfoEndpoint)
8989
if err != nil {
90-
return false, "", err
90+
return false, "", nil, err
9191
}
9292
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
93-
return false, "", errors.New("failed to get user info from Google API: " + resp.Status)
93+
return false, "", nil, errors.New("failed to get user info from Google API: " + resp.Status)
9494
}
9595
defer resp.Body.Close()
9696

97-
var userInfo struct {
98-
Sub string `json:"sub"`
99-
Name string `json:"name"`
100-
Email string `json:"email"`
101-
HD string `json:"hd"`
102-
}
103-
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
104-
return false, "", err
97+
var userInfoMap map[string]any
98+
if err := json.NewDecoder(resp.Body).Decode(&userInfoMap); err != nil {
99+
return false, "", nil, err
105100
}
106101

102+
email, _ := userInfoMap["email"].(string)
103+
hd, _ := userInfoMap["hd"].(string)
104+
107105
if len(p.allowedUsers) == 0 && len(p.allowedWorkspaces) == 0 {
108-
return true, userInfo.Email, nil
106+
return true, email, userInfoMap, nil
109107
}
110108

111-
if slices.Contains(p.allowedUsers, userInfo.Email) {
112-
return true, userInfo.Email, nil
109+
if slices.Contains(p.allowedUsers, email) {
110+
return true, email, userInfoMap, nil
113111
}
114112

115-
if slices.Contains(p.allowedWorkspaces, userInfo.HD) {
116-
return true, userInfo.Email, nil
113+
if slices.Contains(p.allowedWorkspaces, hd) {
114+
return true, email, userInfoMap, nil
117115
}
118116

119-
return false, userInfo.Email, nil
117+
return false, email, userInfoMap, nil
120118
}

0 commit comments

Comments
 (0)