Skip to content

Commit 36b3d75

Browse files
Merge commit from fork
* Fix injection issues * Add exclusion /debug/vars key; consolidate func; apply to bulk and live loader too Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
1 parent 9afc917 commit 36b3d75

9 files changed

Lines changed: 363 additions & 28 deletions

File tree

dgraph/cmd/alpha/http_test.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -838,24 +838,39 @@ func TestHealth(t *testing.T) {
838838
require.True(t, info[0].Uptime > int64(time.Duration(1)))
839839
}
840840

841-
// TestPprofCmdlineNotExposed ensures that /debug/pprof/cmdline is not reachable
842-
// without authentication. The endpoint exposes the full process command line,
843-
// which may include the admin token passed via --security "token=...".
844-
// The other pprof sub-endpoints should remain accessible.
845-
func TestPprofCmdlineNotExposed(t *testing.T) {
846-
// cmdline must be blocked — it leaks the admin token from process args.
841+
// TestCmdlineEndpointsNotExposed ensures that endpoints which expose the full
842+
// process command line are not reachable without authentication. Both
843+
// /debug/pprof/cmdline (net/http/pprof) and /debug/vars (expvar, which
844+
// publishes os.Args as "cmdline") can leak the admin token passed via
845+
// --security "token=...".
846+
func TestCmdlineEndpointsNotExposed(t *testing.T) {
847+
// /debug/pprof/cmdline must be blocked.
847848
resp, err := http.Get(fmt.Sprintf("%s/debug/pprof/cmdline", addr))
848849
require.NoError(t, err)
849850
defer resp.Body.Close()
850851
require.Equal(t, http.StatusNotFound, resp.StatusCode,
851852
"/debug/pprof/cmdline should return 404; got %d", resp.StatusCode)
852853

853-
// Sanity-check that other pprof endpoints are still reachable.
854-
resp2, err := http.Get(fmt.Sprintf("%s/debug/pprof/heap", addr))
854+
// /debug/vars must still be reachable but must NOT include "cmdline".
855+
resp2, err := http.Get(fmt.Sprintf("%s/debug/vars", addr))
855856
require.NoError(t, err)
856857
defer resp2.Body.Close()
857858
require.Equal(t, http.StatusOK, resp2.StatusCode,
858-
"/debug/pprof/heap should return 200; got %d", resp2.StatusCode)
859+
"/debug/vars should return 200; got %d", resp2.StatusCode)
860+
body, err := io.ReadAll(resp2.Body)
861+
require.NoError(t, err)
862+
var vars map[string]json.RawMessage
863+
require.NoError(t, json.Unmarshal(body, &vars))
864+
_, hasCmdline := vars["cmdline"]
865+
require.False(t, hasCmdline,
866+
"/debug/vars response must not contain the cmdline key")
867+
868+
// Sanity-check that other pprof endpoints are still reachable.
869+
resp3, err := http.Get(fmt.Sprintf("%s/debug/pprof/heap", addr))
870+
require.NoError(t, err)
871+
defer resp3.Body.Close()
872+
require.Equal(t, http.StatusOK, resp3.StatusCode,
873+
"/debug/pprof/heap should return 200; got %d", resp3.StatusCode)
859874
}
860875

861876
func setDrainingMode(t *testing.T, enable bool, accessJwt string) {

dgraph/cmd/alpha/run.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -587,18 +587,7 @@ func setupServer(closer *z.Closer) {
587587
x.ServerCloser.AddRunning(3)
588588
go serveGRPC(grpcListener, tlsCfg, x.ServerCloser)
589589

590-
// Block /debug/pprof/cmdline — importing net/http/pprof registers it on
591-
// http.DefaultServeMux, but it exposes the full process command line which
592-
// may include the admin token from --security "token=...".
593-
serverHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
594-
if r.URL.Path == "/debug/pprof/cmdline" {
595-
http.NotFound(w, r)
596-
return
597-
}
598-
http.DefaultServeMux.ServeHTTP(w, r)
599-
})
600-
go x.StartListenHttpAndHttps(httpListener, tlsCfg, x.ServerCloser, serverHandler)
601-
590+
serverHandler := x.SanitizedDefaultServeMux()
602591
go x.StartListenHttpAndHttps(httpListener, tlsCfg, x.ServerCloser, serverHandler)
603592

604593
go func() {

dgraph/cmd/alpha/upsert_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3012,3 +3012,93 @@ func TestLargeStringIndex(t *testing.T) {
30123012
require.Contains(t, dqlSchema,
30133013
`{"predicate":"name_term","type":"string","index":true,"tokenizer":["term"]}`)
30143014
}
3015+
3016+
// TestDQLInjectionViaCondField is a security regression test for LEAD-001.
3017+
// It verifies that a crafted cond field containing an injected DQL query block
3018+
// is rejected by input validation before reaching the query builder.
3019+
func TestDQLInjectionViaCondField(t *testing.T) {
3020+
require.NoError(t, dropAll())
3021+
require.NoError(t, alterSchema(`
3022+
name: string @index(exact) .
3023+
email: string @index(exact) .
3024+
secret: string .
3025+
`))
3026+
3027+
// Seed data that should NOT be readable via a mutation request.
3028+
seed := `{
3029+
set {
3030+
_:u1 <dgraph.type> "User" .
3031+
_:u1 <name> "Alice" .
3032+
_:u1 <email> "[email protected]" .
3033+
_:u1 <secret> "SSN-111-22-3333" .
3034+
3035+
_:u2 <dgraph.type> "User" .
3036+
_:u2 <name> "Bob" .
3037+
_:u2 <email> "[email protected]" .
3038+
_:u2 <secret> "API_KEY_secret_abc123" .
3039+
}
3040+
}`
3041+
_, err := mutationWithTs(mutationInp{body: seed, typ: "application/rdf", commitNow: true})
3042+
require.NoError(t, err)
3043+
3044+
// Craft the injection payload. The cond value closes the @if() clause and
3045+
// appends an entirely new named query block "leak" that would exfiltrate all
3046+
// data if the injection were not blocked.
3047+
injectionPayload := `{
3048+
"query": "{ q(func: uid(0x1)) { uid } }",
3049+
"mutations": [{
3050+
"set": [{"uid": "0x1", "dgraph.type": "Dummy"}],
3051+
"cond": "@if(eq(name, \"nonexistent\"))\n leak(func: has(dgraph.type)) { uid dgraph.type name email secret }"
3052+
}]
3053+
}`
3054+
3055+
// The injection payload must be rejected by cond validation.
3056+
_, err = mutationWithTs(mutationInp{
3057+
body: injectionPayload,
3058+
typ: "application/json",
3059+
commitNow: true,
3060+
})
3061+
require.Error(t, err)
3062+
require.Contains(t, err.Error(), "invalid cond value")
3063+
3064+
// Verify that no mutation was applied — the request was rejected entirely.
3065+
q := `{ q(func: has(dgraph.type)) { uid name } }`
3066+
res, _, err := queryWithTs(queryInp{body: q, typ: "application/dql"})
3067+
require.NoError(t, err)
3068+
require.NotContains(t, res, "Dummy")
3069+
3070+
// Verify that a legitimate conditional upsert still works.
3071+
legitimateUpsert := `{
3072+
"query": "{ q(func: eq(name, \"Alice\")) { v as uid } }",
3073+
"mutations": [{
3074+
"set": [{"uid": "uid(v)", "email": "[email protected]"}],
3075+
"cond": "@if(eq(len(v), 1))"
3076+
}]
3077+
}`
3078+
_, err = mutationWithTs(mutationInp{
3079+
body: legitimateUpsert,
3080+
typ: "application/json",
3081+
commitNow: true,
3082+
})
3083+
require.NoError(t, err)
3084+
}
3085+
3086+
func TestStringWithQuote(t *testing.T) {
3087+
require.NoError(t, dropAll())
3088+
require.NoError(t, alterSchemaWithRetry(`name: string @unique @index(exact) .`))
3089+
mu := `{ set { <0x01> <name> "\"problem\" is the quotes (json)" . } }`
3090+
require.NoError(t, runMutation(mu))
3091+
3092+
var data struct {
3093+
Data struct {
3094+
Q []struct {
3095+
Name string `json:"name"`
3096+
} `json:"q"`
3097+
} `json:"data"`
3098+
}
3099+
q := `{ q(func: has(name)) { name } }`
3100+
res, _, err := queryWithTs(queryInp{body: q, typ: "application/dql"})
3101+
require.NoError(t, err)
3102+
require.NoError(t, json.Unmarshal([]byte(res), &data))
3103+
require.Equal(t, `"problem" is the quotes (json)`, data.Data.Q[0].Name)
3104+
}

dgraph/cmd/bulk/run.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ func run() {
271271
maxOpenFilesWarning()
272272

273273
go func() {
274-
log.Fatal(http.ListenAndServe(opt.HttpAddr, nil))
274+
log.Fatal(http.ListenAndServe(opt.HttpAddr, x.SanitizedDefaultServeMux()))
275275
}()
276276
http.HandleFunc("/jemalloc", x.JemallocHandler)
277277

dgraph/cmd/debug/run.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,7 @@ func run() {
955955
go func() {
956956
for i := 8080; i < 9080; i++ {
957957
fmt.Printf("Listening for /debug HTTP requests at port: %d\n", i)
958-
if err := http.ListenAndServe(fmt.Sprintf("localhost:%d", i), nil); err != nil {
958+
if err := http.ListenAndServe(fmt.Sprintf("localhost:%d", i), x.SanitizedDefaultServeMux()); err != nil {
959959
fmt.Println("Port busy. Trying another one...")
960960
continue
961961
}

dgraph/cmd/live/run.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ func run() error {
751751
z.SetTmpDir(opt.tmpDir)
752752

753753
go func() {
754-
if err := http.ListenAndServe(opt.httpAddr, nil); err != nil {
754+
if err := http.ListenAndServe(opt.httpAddr, x.SanitizedDefaultServeMux()); err != nil {
755755
glog.Errorf("Error while starting HTTP server: %+v", err)
756756
}
757757
}()

edgraph/server.go

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"fmt"
2525
"math"
2626
"net"
27+
"regexp"
2728
"sort"
2829
"strconv"
2930
"strings"
@@ -744,11 +745,102 @@ func validateMutation(ctx context.Context, edges []*pb.DirectedEdge) error {
744745
return nil
745746
}
746747

748+
// validateCondValue checks that a cond string is a well-formed @if(...) or @filter(...)
749+
// clause with balanced parentheses and no trailing content. This prevents DQL injection
750+
// via crafted cond values that close the parenthesized expression and append additional
751+
// query blocks.
752+
func validateCondValue(cond string) error {
753+
cond = strings.TrimSpace(cond)
754+
if cond == "" {
755+
return nil
756+
}
757+
758+
lower := strings.ToLower(cond)
759+
if !strings.HasPrefix(lower, "@if(") && !strings.HasPrefix(lower, "@filter(") {
760+
return errors.Errorf("invalid cond value: must start with @if( or @filter(")
761+
}
762+
763+
openIdx := strings.Index(cond, "(")
764+
if openIdx == -1 {
765+
return errors.Errorf("invalid cond value: missing opening parenthesis")
766+
}
767+
768+
depth := 0
769+
inString := false
770+
escaped := false
771+
closingIdx := -1
772+
773+
for i := openIdx; i < len(cond); i++ {
774+
ch := cond[i]
775+
if escaped {
776+
escaped = false
777+
continue
778+
}
779+
if ch == '\\' {
780+
escaped = true
781+
continue
782+
}
783+
if ch == '"' {
784+
inString = !inString
785+
continue
786+
}
787+
if inString {
788+
continue
789+
}
790+
if ch == '(' {
791+
depth++
792+
} else if ch == ')' {
793+
depth--
794+
if depth == 0 {
795+
closingIdx = i
796+
break
797+
}
798+
}
799+
}
800+
801+
if closingIdx == -1 {
802+
return errors.Errorf("invalid cond value: unbalanced parentheses")
803+
}
804+
805+
trailing := strings.TrimSpace(cond[closingIdx+1:])
806+
if trailing != "" {
807+
return errors.Errorf("invalid cond value: unexpected content after condition")
808+
}
809+
810+
return nil
811+
}
812+
813+
// valVarRegexp matches a valid val(variableName) reference used in upsert mutations.
814+
var valVarRegexp = regexp.MustCompile(`^val\([a-zA-Z_][a-zA-Z0-9_.]*\)$`)
815+
816+
// validateValObjectId checks that an ObjectId starting with "val(" is a well-formed
817+
// val(variableName) reference and contains no injected DQL syntax.
818+
func validateValObjectId(objectId string) error {
819+
if !valVarRegexp.MatchString(objectId) {
820+
return errors.Errorf("invalid val() reference in ObjectId: %q", objectId)
821+
}
822+
return nil
823+
}
824+
825+
// langTagRegexp matches a valid BCP 47 language tag (letters, digits, hyphens).
826+
var langTagRegexp = regexp.MustCompile(`^[a-zA-Z]+(-[a-zA-Z0-9]+)*$`)
827+
828+
// validateLangTag checks that a language tag contains only safe characters.
829+
func validateLangTag(lang string) error {
830+
if lang == "" {
831+
return nil
832+
}
833+
if !langTagRegexp.MatchString(lang) {
834+
return errors.Errorf("invalid language tag: %q", lang)
835+
}
836+
return nil
837+
}
838+
747839
// buildUpsertQuery modifies the query to evaluate the
748840
// @if condition defined in Conditional Upsert.
749-
func buildUpsertQuery(qc *queryContext) string {
841+
func buildUpsertQuery(qc *queryContext) (string, error) {
750842
if qc.req.Query == "" || len(qc.gmuList) == 0 {
751-
return qc.req.Query
843+
return qc.req.Query, nil
752844
}
753845

754846
qc.condVars = make([]string, len(qc.req.Mutations))
@@ -759,6 +851,10 @@ func buildUpsertQuery(qc *queryContext) string {
759851
for i, gmu := range qc.gmuList {
760852
isCondUpsert := strings.TrimSpace(gmu.Cond) != ""
761853
if isCondUpsert {
854+
if err := validateCondValue(gmu.Cond); err != nil {
855+
return "", err
856+
}
857+
762858
qc.condVars[i] = fmt.Sprintf("__dgraph_upsertcheck_%v__", strconv.Itoa(i))
763859
qc.uidRes[qc.condVars[i]] = nil
764860
// @if in upsert is same as @filter in the query
@@ -788,7 +884,7 @@ func buildUpsertQuery(qc *queryContext) string {
788884
}
789885

790886
x.Check2(upsertQB.WriteString(`}`))
791-
return upsertQB.String()
887+
return upsertQB.String(), nil
792888
}
793889

794890
// updateMutations updates the mutation and replaces uid(var) and val(var) with
@@ -1599,7 +1695,12 @@ func parseRequest(ctx context.Context, qc *queryContext) error {
15991695

16001696
qc.uidRes = make(map[string][]string)
16011697
qc.valRes = make(map[string]map[uint64]types.Val)
1602-
upsertQuery = buildUpsertQuery(qc)
1698+
var err error
1699+
upsertQuery, err = buildUpsertQuery(qc)
1700+
if err != nil {
1701+
return err
1702+
}
1703+
16031704
needVars = findMutationVars(qc)
16041705
if upsertQuery == "" {
16051706
if len(needVars) > 0 {
@@ -1746,6 +1847,9 @@ func addQueryIfUnique(qctx context.Context, qc *queryContext) error {
17461847
// during the automatic serialization of a structure into JSON.
17471848
predicateName := fmt.Sprintf("<%v>", pred.Predicate)
17481849
if pred.Lang != "" {
1850+
if err := validateLangTag(pred.Lang); err != nil {
1851+
return err
1852+
}
17491853
predicateName = fmt.Sprintf("%v@%v", predicateName, pred.Lang)
17501854
}
17511855

@@ -1780,6 +1884,9 @@ func addQueryIfUnique(qctx context.Context, qc *queryContext) error {
17801884
}
17811885
qc.uniqueVars[uniqueVarMapKey] = uniquePredMeta{queryVar: queryVar}
17821886
} else {
1887+
if err := validateValObjectId(pred.ObjectId); err != nil {
1888+
return err
1889+
}
17831890
valQueryVar := fmt.Sprintf("__dgraph_uniquecheck_val_%v__", uniqueVarMapKey)
17841891
query := fmt.Sprintf(`%v as var(func: eq(%v,%v)){
17851892
uid

0 commit comments

Comments
 (0)