Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@ func parseAttributeMap(s string) map[string][]string {
return result
}

func parseHeaderMapping(s string) map[string]string {
result := make(map[string]string)
if s == "" {
return result
}
parts := splitWithEscapes(s, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
colonIdx := strings.LastIndex(part, ":")
if colonIdx == -1 {
continue
}
pointer := strings.TrimSpace(part[:colonIdx])
header := strings.TrimSpace(part[colonIdx+1:])
if pointer != "" && header != "" {
result[pointer] = header
}
}
return result
}

type proxyRunnerFunc func(
listen string,
tlsListen string,
Expand Down Expand Up @@ -130,6 +154,7 @@ type proxyRunnerFunc func(
proxyBearerToken string,
proxyTarget []string,
httpStreamingOnly bool,
headerMapping map[string]string,
) error

func main() {
Expand Down Expand Up @@ -174,6 +199,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
var passwordHash string
var proxyBearerToken string
var proxyHeaders string
var headerMapping string
var httpStreamingOnly bool
var trustedProxies string

Expand Down Expand Up @@ -255,6 +281,8 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
}
}

headerMappingMap := parseHeaderMapping(headerMapping)

if err := run(
listen,
tlsListen,
Expand Down Expand Up @@ -294,6 +322,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
proxyBearerToken,
args,
httpStreamingOnly,
headerMappingMap,
); err != nil {
panic(err)
}
Expand Down Expand Up @@ -347,6 +376,7 @@ func newRootCommand(run proxyRunnerFunc) *cobra.Command {
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", getEnvWithDefault("TRUSTED_PROXIES", ""), "Comma-separated list of trusted proxies (IP addresses or CIDR ranges)")
rootCmd.Flags().StringVar(&proxyHeaders, "proxy-headers", getEnvWithDefault("PROXY_HEADERS", ""), "Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2)")
rootCmd.Flags().BoolVar(&httpStreamingOnly, "http-streaming-only", getEnvBoolWithDefault("HTTP_STREAMING_ONLY", false), "Reject SSE (GET) requests and keep the backend in HTTP streaming-only mode")
rootCmd.Flags().StringVar(&headerMapping, "header-mapping", getEnvWithDefault("HEADER_MAPPING", ""), "Comma-separated mapping of userinfo JSON pointer paths to header names (e.g., /email:X-Forwarded-Email,/preferred_username:X-Forwarded-User)")

return rootCmd
}
60 changes: 60 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,64 @@ func TestParseAttributeMap(t *testing.T) {
}
}

func TestParseHeaderMapping(t *testing.T) {
testCases := []struct {
name string
input string
expected map[string]string
}{
{
name: "empty string",
input: "",
expected: map[string]string{},
},
{
name: "single mapping",
input: "/email:X-Forwarded-Email",
expected: map[string]string{
"/email": "X-Forwarded-Email",
},
},
{
name: "multiple mappings",
input: "/email:X-Forwarded-Email,/preferred_username:X-Forwarded-User",
expected: map[string]string{
"/email": "X-Forwarded-Email",
"/preferred_username": "X-Forwarded-User",
},
},
{
name: "nested JSON pointer",
input: "/org/team:X-Forwarded-Team",
expected: map[string]string{
"/org/team": "X-Forwarded-Team",
},
},
{
name: "whitespace trimming",
input: " /email : X-Forwarded-Email , /sub : X-Forwarded-Sub ",
expected: map[string]string{
"/email": "X-Forwarded-Email",
"/sub": "X-Forwarded-Sub",
},
},
{
name: "no colon - skipped",
input: "invalid",
expected: map[string]string{},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := parseHeaderMapping(tc.input)
if !reflect.DeepEqual(result, tc.expected) {
t.Errorf("Expected %v, got %v", tc.expected, result)
}
})
}
}

func TestGetEnvWithDefault(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -344,6 +402,7 @@ func TestNewRootCommand_HTTPStreamingOnlyFlag(t *testing.T) {
proxyBearerToken string,
proxyTarget []string,
httpStreamingOnly bool,
headerMapping map[string]string,
) error {
streamingOnly = httpStreamingOnly
receivedTargets = proxyTarget
Expand Down Expand Up @@ -407,6 +466,7 @@ func TestNewRootCommand_HTTPStreamingOnlyFromEnv(t *testing.T) {
proxyBearerToken string,
proxyTarget []string,
httpStreamingOnly bool,
headerMapping map[string]string,
) error {
streamingOnly = httpStreamingOnly
return nil
Expand Down
12 changes: 11 additions & 1 deletion pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"embed"
"encoding/json"
"errors"
"html/template"
"net/http"
Expand Down Expand Up @@ -68,6 +69,8 @@ const (
SessionKeyAuthorized = "authorized"
SessionKeyRedirectURL = "redirect_url"
SessionKeyOAuthState = "oauth_state"
SessionKeyUserID = "user_id"
SessionKeyUserInfo = "user_info"
)

func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
Expand All @@ -87,7 +90,7 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
a.renderError(c, err)
return
}
ok, user, err := provider.Authorization(c, token)
ok, user, userInfo, err := provider.Authorization(c, token)
if err != nil {
a.renderError(c, err)
return
Expand All @@ -97,6 +100,12 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
return
}
session.Set(SessionKeyAuthorized, true)
session.Set(SessionKeyUserID, user)
if userInfo != nil {
if userInfoJSON, err := json.Marshal(userInfo); err == nil {
session.Set(SessionKeyUserInfo, string(userInfoJSON))
}
}
redirectURL := session.Get(SessionKeyRedirectURL)
if redirectURL != nil {
session.Delete(SessionKeyRedirectURL)
Expand Down Expand Up @@ -177,6 +186,7 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {

session := sessions.Default(c)
session.Set(SessionKeyAuthorized, true)
session.Set(SessionKeyUserID, PasswordUserID)
redirectURL := session.Get(SessionKeyRedirectURL)
if redirectURL != nil {
session.Delete(SessionKeyRedirectURL)
Expand Down
4 changes: 2 additions & 2 deletions pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestAuthenticationFlow(t *testing.T) {
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", nil)
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", map[string]any{"email": "[email protected]"}, nil)

// Create AuthRouter
authRouter, err := NewAuthRouter(nil, false, mockProvider)
Expand Down Expand Up @@ -146,7 +146,7 @@ func TestAuthenticationFlow(t *testing.T) {
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", nil)
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", map[string]any{"email": "[email protected]"}, nil)

// Create AuthRouter
authRouter, err := NewAuthRouter(nil, false, mockProvider)
Expand Down
40 changes: 20 additions & 20 deletions pkg/auth/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,30 +85,30 @@ func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token,
return token, nil
}

func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) {
func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) {
client := p.oauth2.Client(ctx, token)
resp1, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user")))
if err != nil {
return false, "", err
return false, "", nil, err
}
if resp1.StatusCode < 200 || resp1.StatusCode >= 300 {
return false, "", errors.New("failed to get user info from GitHub API: " + resp1.Status)
return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp1.Status)
}
defer resp1.Body.Close()

var userInfo struct {
Login string `json:"login"`
}
if err := json.NewDecoder(resp1.Body).Decode(&userInfo); err != nil {
return false, "", err
var userInfoMap map[string]any
if err := json.NewDecoder(resp1.Body).Decode(&userInfoMap); err != nil {
return false, "", nil, err
}

login, _ := userInfoMap["login"].(string)

if len(p.allowedUsers) == 0 && len(p.allowedOrgs) == 0 {
return true, userInfo.Login, nil
return true, login, userInfoMap, nil
}

if slices.Contains(p.allowedUsers, userInfo.Login) {
return true, userInfo.Login, nil
if slices.Contains(p.allowedUsers, login) {
return true, login, userInfoMap, nil
}

allowedOrgTeams := []string{}
Expand All @@ -124,31 +124,31 @@ func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token)
if len(allowedOrgs) > 0 {
resp2, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/orgs")))
if err != nil {
return false, "", err
return false, "", nil, err
}
if resp2.StatusCode < 200 || resp2.StatusCode >= 300 {
return false, "", errors.New("failed to get user info from GitHub API: " + resp2.Status)
return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp2.Status)
}
defer resp2.Body.Close()
var orgInfo []struct {
Login string `json:"login"`
}
if err := json.NewDecoder(resp2.Body).Decode(&orgInfo); err != nil {
return false, "", err
return false, "", nil, err
}
for _, o := range orgInfo {
if slices.Contains(allowedOrgs, o.Login) {
return true, userInfo.Login, nil
return true, login, userInfoMap, nil
}
}
}
if len(allowedOrgTeams) > 0 {
resp3, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/teams")))
if err != nil {
return false, "", err
return false, "", nil, err
}
if resp3.StatusCode < 200 || resp3.StatusCode >= 300 {
return false, "", errors.New("failed to get user info from GitHub API: " + resp3.Status)
return false, "", nil, errors.New("failed to get user info from GitHub API: " + resp3.Status)
}
defer resp3.Body.Close()
var teamInfo []struct {
Expand All @@ -158,14 +158,14 @@ func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token)
Slug string `json:"slug"`
}
if err := json.NewDecoder(resp3.Body).Decode(&teamInfo); err != nil {
return false, "", err
return false, "", nil, err
}
for _, team := range teamInfo {
if slices.Contains(allowedOrgTeams, team.Organization.Login+":"+team.Slug) {
return true, userInfo.Login, nil
return true, login, userInfoMap, nil
}
}
}

return false, userInfo.Login, nil
return false, login, userInfoMap, nil
}
2 changes: 1 addition & 1 deletion pkg/auth/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func TestGitHubProviderAuthorization(t *testing.T) {
})

// Call the Authorization method
ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"})
ok, _, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"})
require.NoError(t, err)
require.Equal(t, expect, ok)
})
Expand Down
32 changes: 15 additions & 17 deletions pkg/auth/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,38 +83,36 @@ func (p *googleProvider) Exchange(c *gin.Context, state string) (*oauth2.Token,
return token, nil
}

func (p *googleProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) {
func (p *googleProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, map[string]any, error) {
client := p.oauth2.Client(ctx, token)
resp, err := client.Get(p.userinfoEndpoint)
if err != nil {
return false, "", err
return false, "", nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return false, "", errors.New("failed to get user info from Google API: " + resp.Status)
return false, "", nil, errors.New("failed to get user info from Google API: " + resp.Status)
}
defer resp.Body.Close()

var userInfo struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
HD string `json:"hd"`
}
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return false, "", err
var userInfoMap map[string]any
if err := json.NewDecoder(resp.Body).Decode(&userInfoMap); err != nil {
return false, "", nil, err
}

email, _ := userInfoMap["email"].(string)
hd, _ := userInfoMap["hd"].(string)

if len(p.allowedUsers) == 0 && len(p.allowedWorkspaces) == 0 {
return true, userInfo.Email, nil
return true, email, userInfoMap, nil
}

if slices.Contains(p.allowedUsers, userInfo.Email) {
return true, userInfo.Email, nil
if slices.Contains(p.allowedUsers, email) {
return true, email, userInfoMap, nil
}

if slices.Contains(p.allowedWorkspaces, userInfo.HD) {
return true, userInfo.Email, nil
if slices.Contains(p.allowedWorkspaces, hd) {
return true, email, userInfoMap, nil
}

return false, userInfo.Email, nil
return false, email, userInfoMap, nil
}
Loading
Loading