Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 41 additions & 18 deletions app/artifact-cas/internal/server/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package server

import (
"context"
"crypto/ecdsa"
"crypto/tls"
"fmt"
"os"
Expand Down Expand Up @@ -52,22 +53,16 @@ import (

// NewGRPCServer new a gRPC server.
func NewGRPCServer(c *conf.Server, authConf *conf.Auth, byteService *service.ByteStreamService, rSvc *service.ResourceService, providers backend.Providers, validator protovalidate.Validator, logger log.Logger) (*grpc.Server, error) {
log := log.NewHelper(logger)
// Load the key on initialization instead of on every request
// Parse the public key once on initialization instead of on every request
// TODO: implement jwks endpoint
publicKeyPath := authConf.GetPublicKeyPath()
if publicKeyPath == "" {
// Maintain backwards compatibility
publicKeyPath = authConf.RobotAccountPublicKeyPath
}

log.Debugw("msg", "loading public key from file", "file", publicKeyPath)

rawKey, err := os.ReadFile(publicKeyPath)
publicKey, err := parsePublicKey(authConf, logger)
if err != nil {
return nil, fmt.Errorf("failed to load public key: %w", err)
return nil, err
}

// Share a single keyfunc closure over the parsed key across all interceptors
keyFunc := loadPublicKey(publicKey)

var opts = []grpc.ServerOption{
// Kratos middleware are in practice unary interceptors
grpc.Middleware(
Expand All @@ -83,7 +78,7 @@ func NewGRPCServer(c *conf.Server, authConf *conf.Auth, byteService *service.Byt
// If we require a logged in user we
selector.Server(
jwtMiddleware.Server(
loadPublicKey(rawKey),
keyFunc,
jwtMiddleware.WithSigningMethod(casJWT.SigningMethod),
jwtMiddleware.WithClaims(func() jwt.Claims { return &casJWT.Claims{} })),
).Match(requireAuthentication()).Build(),
Expand All @@ -92,7 +87,7 @@ func NewGRPCServer(c *conf.Server, authConf *conf.Auth, byteService *service.Byt
// Streaming interceptors
grpc.StreamInterceptor(
grpcselector.StreamServerInterceptor(
grpc_auth.StreamServerInterceptor(jwtAuthFunc(loadPublicKey(rawKey), casJWT.SigningMethod)),
grpc_auth.StreamServerInterceptor(jwtAuthFunc(keyFunc, casJWT.SigningMethod)),
grpcselector.MatchFunc(allButReflectionAPI),
),
// grpc prometheus metrics
Expand Down Expand Up @@ -163,10 +158,38 @@ func allButReflectionAPI(_ context.Context, callMeta interceptors.CallMeta) bool
return !reflectionServiceRegexp.MatchString(callMeta.Service)
}

// load key for verification
func loadPublicKey(rawKey []byte) jwt.Keyfunc {
return func(token *jwt.Token) (interface{}, error) {
return jwt.ParseECPublicKeyFromPEM(rawKey)
// parsePublicKey resolves the configured public key path, reads the file and parses
// the EC public key once. A malformed key therefore fails at server construction
// instead of surfacing as a per-request authentication error.
func parsePublicKey(authConf *conf.Auth, logger log.Logger) (*ecdsa.PublicKey, error) {
l := log.NewHelper(logger)

publicKeyPath := authConf.GetPublicKeyPath()
if publicKeyPath == "" {
// Maintain backwards compatibility with the deprecated field.
publicKeyPath = authConf.RobotAccountPublicKeyPath //nolint:staticcheck // intentional fallback to the deprecated field
}

l.Debugw("msg", "loading public key from file", "file", publicKeyPath)

rawKey, err := os.ReadFile(publicKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to load public key: %w", err)
}

publicKey, err := jwt.ParseECPublicKeyFromPEM(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %w", err)
}

return publicKey, nil
}

// loadPublicKey returns a jwt.Keyfunc that hands back the pre-parsed public key,
// avoiding a PEM re-parse on every request.
func loadPublicKey(publicKey *ecdsa.PublicKey) jwt.Keyfunc {
return func(_ *jwt.Token) (interface{}, error) {
return publicKey, nil
}
}

Expand Down
56 changes: 56 additions & 0 deletions app/artifact-cas/internal/server/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import (
"testing"
"time"

"github.com/chainloop-dev/chainloop/app/artifact-cas/internal/conf"
robotaccount "github.com/chainloop-dev/chainloop/internal/robotaccount/cas"
"github.com/go-kratos/kratos/v2/log"
jwtMiddleware "github.com/go-kratos/kratos/v2/middleware/auth/jwt"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
Expand Down Expand Up @@ -181,6 +183,60 @@ func TestAllButReflectionAPI(t *testing.T) {
}
}

func TestLoadPublicKey(t *testing.T) {
rawKey, err := os.ReadFile("./testdata/test-key.ec.pub")
require.NoError(t, err)
want, err := jwt.ParseECPublicKeyFromPEM(rawKey)
require.NoError(t, err)

// The keyfunc must hand back the pre-parsed key without re-parsing the PEM
got, err := loadPublicKey(want)(&jwt.Token{})
require.NoError(t, err)
assert.Same(t, want, got)
}

func TestParsePublicKey(t *testing.T) {
testCases := []struct {
name string
authConf *conf.Auth
wantErr string
}{
{
name: "valid public key via PublicKeyPath",
authConf: &conf.Auth{PublicKeyPath: "./testdata/test-key.ec.pub"},
},
{
name: "valid public key via deprecated RobotAccountPublicKeyPath",
authConf: &conf.Auth{RobotAccountPublicKeyPath: "./testdata/test-key.ec.pub"},
},
{
name: "missing file",
authConf: &conf.Auth{PublicKeyPath: "./testdata/does-not-exist.pub"},
wantErr: "failed to load public key",
},
{
name: "not a public key PEM",
authConf: &conf.Auth{PublicKeyPath: "./testdata/test-key.ec.pem"},
wantErr: "failed to parse public key",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := parsePublicKey(tc.authConf, log.DefaultLogger)
if tc.wantErr != "" {
require.Error(t, err)
assert.ErrorContains(t, err, tc.wantErr)
assert.Nil(t, got)
return
}

require.NoError(t, err)
assert.NotNil(t, got)
})
}
}

func loadTestPublicKey(path string) jwt.Keyfunc {
rawKey, _ := os.ReadFile(path)
return func(token *jwt.Token) (interface{}, error) {
Expand Down
17 changes: 4 additions & 13 deletions app/artifact-cas/internal/server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
package server

import (
"fmt"
"os"

api "github.com/chainloop-dev/chainloop/app/artifact-cas/api/cas/v1"
"github.com/chainloop-dev/chainloop/app/artifact-cas/internal/conf"
"github.com/chainloop-dev/chainloop/app/artifact-cas/internal/service"
Expand Down Expand Up @@ -50,22 +47,16 @@ func NewHTTPServer(c *conf.Server, authConf *conf.Auth, downloadSvc *service.Dow
opts = append(opts, http.Timeout(c.Http.Timeout.AsDuration()))
}

// Load the key on initialization instead of on every request
// Parse the public key once on initialization instead of on every request
// TODO: implement jwks endpoint
publicKeyPath := authConf.GetPublicKeyPath()
if publicKeyPath == "" {
// Maintain backwards compatibility
publicKeyPath = authConf.RobotAccountPublicKeyPath
}

rawKey, err := os.ReadFile(publicKeyPath)
publicKey, err := parsePublicKey(authConf, logger)
if err != nil {
return nil, fmt.Errorf("failed to load public key: %w", err)
return nil, err
}

srv := http.NewServer(opts...)

downloadHandler := middlewares_http.AuthFromQueryParam(loadPublicKey(rawKey), claimsFunc(), casJWT.SigningMethod, downloadSvc)
downloadHandler := middlewares_http.AuthFromQueryParam(loadPublicKey(publicKey), claimsFunc(), casJWT.SigningMethod, downloadSvc)
srv.Handle(service.DownloadPath, CORSMiddleware(c.GetHttp().GetCors().GetAllowOrigins(), downloadHandler))
api.RegisterStatusServiceHTTPServer(srv, service.NewStatusService(Version, providers))
return srv, nil
Expand Down
Loading