@@ -13,6 +13,7 @@ import (
1313 "fmt"
1414 "math"
1515 "net"
16+ "regexp"
1617 "sort"
1718 "strconv"
1819 "strings"
@@ -709,11 +710,102 @@ func validateMutation(ctx context.Context, edges []*pb.DirectedEdge) error {
709710 return nil
710711}
711712
713+ // validateCondValue checks that a cond string is a well-formed @if(...) or @filter(...)
714+ // clause with balanced parentheses and no trailing content. This prevents DQL injection
715+ // via crafted cond values that close the parenthesized expression and append additional
716+ // query blocks.
717+ func validateCondValue (cond string ) error {
718+ cond = strings .TrimSpace (cond )
719+ if cond == "" {
720+ return nil
721+ }
722+
723+ lower := strings .ToLower (cond )
724+ if ! strings .HasPrefix (lower , "@if(" ) && ! strings .HasPrefix (lower , "@filter(" ) {
725+ return errors .Errorf ("invalid cond value: must start with @if( or @filter(" )
726+ }
727+
728+ openIdx := strings .Index (cond , "(" )
729+ if openIdx == - 1 {
730+ return errors .Errorf ("invalid cond value: missing opening parenthesis" )
731+ }
732+
733+ depth := 0
734+ inString := false
735+ escaped := false
736+ closingIdx := - 1
737+
738+ for i := openIdx ; i < len (cond ); i ++ {
739+ ch := cond [i ]
740+ if escaped {
741+ escaped = false
742+ continue
743+ }
744+ if ch == '\\' {
745+ escaped = true
746+ continue
747+ }
748+ if ch == '"' {
749+ inString = ! inString
750+ continue
751+ }
752+ if inString {
753+ continue
754+ }
755+ if ch == '(' {
756+ depth ++
757+ } else if ch == ')' {
758+ depth --
759+ if depth == 0 {
760+ closingIdx = i
761+ break
762+ }
763+ }
764+ }
765+
766+ if closingIdx == - 1 {
767+ return errors .Errorf ("invalid cond value: unbalanced parentheses" )
768+ }
769+
770+ trailing := strings .TrimSpace (cond [closingIdx + 1 :])
771+ if trailing != "" {
772+ return errors .Errorf ("invalid cond value: unexpected content after condition" )
773+ }
774+
775+ return nil
776+ }
777+
778+ // valVarRegexp matches a valid val(variableName) reference used in upsert mutations.
779+ var valVarRegexp = regexp .MustCompile (`^val\([a-zA-Z_][a-zA-Z0-9_.]*\)$` )
780+
781+ // validateValObjectId checks that an ObjectId starting with "val(" is a well-formed
782+ // val(variableName) reference and contains no injected DQL syntax.
783+ func validateValObjectId (objectId string ) error {
784+ if ! valVarRegexp .MatchString (objectId ) {
785+ return errors .Errorf ("invalid val() reference in ObjectId: %q" , objectId )
786+ }
787+ return nil
788+ }
789+
790+ // langTagRegexp matches a valid BCP 47 language tag (letters, digits, hyphens).
791+ var langTagRegexp = regexp .MustCompile (`^[a-zA-Z]+(-[a-zA-Z0-9]+)*$` )
792+
793+ // validateLangTag checks that a language tag contains only safe characters.
794+ func validateLangTag (lang string ) error {
795+ if lang == "" {
796+ return nil
797+ }
798+ if ! langTagRegexp .MatchString (lang ) {
799+ return errors .Errorf ("invalid language tag: %q" , lang )
800+ }
801+ return nil
802+ }
803+
712804// buildUpsertQuery modifies the query to evaluate the
713805// @if condition defined in Conditional Upsert.
714- func buildUpsertQuery (qc * queryContext ) string {
806+ func buildUpsertQuery (qc * queryContext ) ( string , error ) {
715807 if qc .req .Query == "" || len (qc .gmuList ) == 0 {
716- return qc .req .Query
808+ return qc .req .Query , nil
717809 }
718810
719811 qc .condVars = make ([]string , len (qc .req .Mutations ))
@@ -724,6 +816,10 @@ func buildUpsertQuery(qc *queryContext) string {
724816 for i , gmu := range qc .gmuList {
725817 isCondUpsert := strings .TrimSpace (gmu .Cond ) != ""
726818 if isCondUpsert {
819+ if err := validateCondValue (gmu .Cond ); err != nil {
820+ return "" , err
821+ }
822+
727823 qc .condVars [i ] = fmt .Sprintf ("__dgraph_upsertcheck_%v__" , strconv .Itoa (i ))
728824 qc.uidRes [qc.condVars [i ]] = nil
729825 // @if in upsert is same as @filter in the query
@@ -753,7 +849,7 @@ func buildUpsertQuery(qc *queryContext) string {
753849 }
754850
755851 x .Check2 (upsertQB .WriteString (`}` ))
756- return upsertQB .String ()
852+ return upsertQB .String (), nil
757853}
758854
759855// updateMutations updates the mutation and replaces uid(var) and val(var) with
@@ -1581,7 +1677,11 @@ func parseRequest(ctx context.Context, qc *queryContext) error {
15811677
15821678 qc .uidRes = make (map [string ][]string )
15831679 qc .valRes = make (map [string ]* types.ShardedMap )
1584- upsertQuery = buildUpsertQuery (qc )
1680+ var err error
1681+ upsertQuery , err = buildUpsertQuery (qc )
1682+ if err != nil {
1683+ return err
1684+ }
15851685 needVars = findMutationVars (qc )
15861686 if upsertQuery == "" {
15871687 if len (needVars ) > 0 {
@@ -1777,6 +1877,9 @@ func addQueryIfUnique(qctx context.Context, qc *queryContext) error {
17771877 // during the automatic serialization of a structure into JSON.
17781878 predicateName := fmt .Sprintf ("<%v>" , pred .Predicate )
17791879 if pred .Lang != "" {
1880+ if err := validateLangTag (pred .Lang ); err != nil {
1881+ return err
1882+ }
17801883 predicateName = fmt .Sprintf ("%v@%v" , predicateName , pred .Lang )
17811884 }
17821885
@@ -1814,6 +1917,9 @@ func addQueryIfUnique(qctx context.Context, qc *queryContext) error {
18141917 }
18151918 qc .uniqueVars [uniqueVarMapKey ] = uniquePredMeta {queryVar : queryVar }
18161919 } else {
1920+ if err := validateValObjectId (pred .ObjectId ); err != nil {
1921+ return err
1922+ }
18171923 valQueryVar := fmt .Sprintf ("__dgraph_uniquecheck_val_%v__" , uniqueVarMapKey )
18181924 query := fmt .Sprintf (`%v as var(func: eq(%v,%v)){
18191925 uid
0 commit comments