@@ -24,6 +24,7 @@ import (
2424 "fmt"
2525 "math"
2626 "net"
27+ "regexp"
2728 "sort"
2829 "strconv"
2930 "strings"
@@ -744,11 +745,102 @@ func validateMutation(ctx context.Context, edges []*pb.DirectedEdge) error {
744745 return nil
745746}
746747
748+ // validateCondValue checks that a cond string is a well-formed @if(...) or @filter(...)
749+ // clause with balanced parentheses and no trailing content. This prevents DQL injection
750+ // via crafted cond values that close the parenthesized expression and append additional
751+ // query blocks.
752+ func validateCondValue (cond string ) error {
753+ cond = strings .TrimSpace (cond )
754+ if cond == "" {
755+ return nil
756+ }
757+
758+ lower := strings .ToLower (cond )
759+ if ! strings .HasPrefix (lower , "@if(" ) && ! strings .HasPrefix (lower , "@filter(" ) {
760+ return errors .Errorf ("invalid cond value: must start with @if( or @filter(" )
761+ }
762+
763+ openIdx := strings .Index (cond , "(" )
764+ if openIdx == - 1 {
765+ return errors .Errorf ("invalid cond value: missing opening parenthesis" )
766+ }
767+
768+ depth := 0
769+ inString := false
770+ escaped := false
771+ closingIdx := - 1
772+
773+ for i := openIdx ; i < len (cond ); i ++ {
774+ ch := cond [i ]
775+ if escaped {
776+ escaped = false
777+ continue
778+ }
779+ if ch == '\\' {
780+ escaped = true
781+ continue
782+ }
783+ if ch == '"' {
784+ inString = ! inString
785+ continue
786+ }
787+ if inString {
788+ continue
789+ }
790+ if ch == '(' {
791+ depth ++
792+ } else if ch == ')' {
793+ depth --
794+ if depth == 0 {
795+ closingIdx = i
796+ break
797+ }
798+ }
799+ }
800+
801+ if closingIdx == - 1 {
802+ return errors .Errorf ("invalid cond value: unbalanced parentheses" )
803+ }
804+
805+ trailing := strings .TrimSpace (cond [closingIdx + 1 :])
806+ if trailing != "" {
807+ return errors .Errorf ("invalid cond value: unexpected content after condition" )
808+ }
809+
810+ return nil
811+ }
812+
813+ // valVarRegexp matches a valid val(variableName) reference used in upsert mutations.
814+ var valVarRegexp = regexp .MustCompile (`^val\([a-zA-Z_][a-zA-Z0-9_.]*\)$` )
815+
816+ // validateValObjectId checks that an ObjectId starting with "val(" is a well-formed
817+ // val(variableName) reference and contains no injected DQL syntax.
818+ func validateValObjectId (objectId string ) error {
819+ if ! valVarRegexp .MatchString (objectId ) {
820+ return errors .Errorf ("invalid val() reference in ObjectId: %q" , objectId )
821+ }
822+ return nil
823+ }
824+
825+ // langTagRegexp matches a valid BCP 47 language tag (letters, digits, hyphens).
826+ var langTagRegexp = regexp .MustCompile (`^[a-zA-Z]+(-[a-zA-Z0-9]+)*$` )
827+
828+ // validateLangTag checks that a language tag contains only safe characters.
829+ func validateLangTag (lang string ) error {
830+ if lang == "" {
831+ return nil
832+ }
833+ if ! langTagRegexp .MatchString (lang ) {
834+ return errors .Errorf ("invalid language tag: %q" , lang )
835+ }
836+ return nil
837+ }
838+
747839// buildUpsertQuery modifies the query to evaluate the
748840// @if condition defined in Conditional Upsert.
749- func buildUpsertQuery (qc * queryContext ) string {
841+ func buildUpsertQuery (qc * queryContext ) ( string , error ) {
750842 if qc .req .Query == "" || len (qc .gmuList ) == 0 {
751- return qc .req .Query
843+ return qc .req .Query , nil
752844 }
753845
754846 qc .condVars = make ([]string , len (qc .req .Mutations ))
@@ -759,6 +851,10 @@ func buildUpsertQuery(qc *queryContext) string {
759851 for i , gmu := range qc .gmuList {
760852 isCondUpsert := strings .TrimSpace (gmu .Cond ) != ""
761853 if isCondUpsert {
854+ if err := validateCondValue (gmu .Cond ); err != nil {
855+ return "" , err
856+ }
857+
762858 qc .condVars [i ] = fmt .Sprintf ("__dgraph_upsertcheck_%v__" , strconv .Itoa (i ))
763859 qc.uidRes [qc.condVars [i ]] = nil
764860 // @if in upsert is same as @filter in the query
@@ -788,7 +884,7 @@ func buildUpsertQuery(qc *queryContext) string {
788884 }
789885
790886 x .Check2 (upsertQB .WriteString (`}` ))
791- return upsertQB .String ()
887+ return upsertQB .String (), nil
792888}
793889
794890// updateMutations updates the mutation and replaces uid(var) and val(var) with
@@ -1599,7 +1695,12 @@ func parseRequest(ctx context.Context, qc *queryContext) error {
15991695
16001696 qc .uidRes = make (map [string ][]string )
16011697 qc .valRes = make (map [string ]map [uint64 ]types.Val )
1602- upsertQuery = buildUpsertQuery (qc )
1698+ var err error
1699+ upsertQuery , err = buildUpsertQuery (qc )
1700+ if err != nil {
1701+ return err
1702+ }
1703+
16031704 needVars = findMutationVars (qc )
16041705 if upsertQuery == "" {
16051706 if len (needVars ) > 0 {
@@ -1746,6 +1847,9 @@ func addQueryIfUnique(qctx context.Context, qc *queryContext) error {
17461847 // during the automatic serialization of a structure into JSON.
17471848 predicateName := fmt .Sprintf ("<%v>" , pred .Predicate )
17481849 if pred .Lang != "" {
1850+ if err := validateLangTag (pred .Lang ); err != nil {
1851+ return err
1852+ }
17491853 predicateName = fmt .Sprintf ("%v@%v" , predicateName , pred .Lang )
17501854 }
17511855
@@ -1780,6 +1884,9 @@ func addQueryIfUnique(qctx context.Context, qc *queryContext) error {
17801884 }
17811885 qc .uniqueVars [uniqueVarMapKey ] = uniquePredMeta {queryVar : queryVar }
17821886 } else {
1887+ if err := validateValObjectId (pred .ObjectId ); err != nil {
1888+ return err
1889+ }
17831890 valQueryVar := fmt .Sprintf ("__dgraph_uniquecheck_val_%v__" , uniqueVarMapKey )
17841891 query := fmt .Sprintf (`%v as var(func: eq(%v,%v)){
17851892 uid
0 commit comments