Skip to content

Commit 96f976d

Browse files
feat(graphql): Add support for ef and distance_threshold in generated GraphQL queries for similarity search (#9562)
**Description** This PR adds support for the new effort and distance threshold in generated GraphQL queries for types that have embedding vectors. Note. There's a breaking change in the resulting computed distance in that for cosine and dotproduct indexes, the computed distance is no longer divided by 2. The results now correctly represent the distances stored in the index. **Checklist** - [x] The PR title follows the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/#summary) syntax, leading with `fix:`, `feat:`, `chore:`, `ci:`, etc. - [x] Code compiles correctly and linting (via trunk) passes locally - [x] Tests added for new functionality, or regression tests for bug fixes added as applicable - [ ] For public APIs, new features, etc., a PR on the [docs repo](https://github.com/dgraph-io/dgraph-docs) staged and linked here. This process can be simplified by going to the [public docs site](https://docs.dgraph.io/) and clicking the "Edit this page" button at the bottom of page(s) relevant to your changes. Ensure that you indicate in the PR that this is an **unreleased** feature so that it does not get merged into the main docs prematurely.
1 parent 4e881c0 commit 96f976d

5 files changed

Lines changed: 309 additions & 68 deletions

File tree

graphql/resolve/query_rewriter.go

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -651,9 +651,9 @@ func rewriteAsSimilarByIdQuery(
651651
distanceFormula := "math(sqrt((v2 - v1) dot (v2 - v1)))" // default - euclidean
652652

653653
if metric == schema.SimilarSearchMetricDotProduct {
654-
distanceFormula = "math((1.0 - (v1 dot v2)) /2.0)"
654+
distanceFormula = "math(1.0 - (v1 dot v2))"
655655
} else if metric == schema.SimilarSearchMetricCosine {
656-
distanceFormula = "math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)"
656+
distanceFormula = "math(1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) )))"
657657
}
658658

659659
// First generate the query to fetch the uid
@@ -698,6 +698,32 @@ func rewriteAsSimilarByIdQuery(
698698
// v2 as Product.embedding
699699
// distance as math((v2 - v1) dot (v2 - v1))
700700
// }
701+
similarToArgs := []dql.Arg{
702+
{
703+
Value: pred,
704+
},
705+
{
706+
Value: fmt.Sprintf("%v", topK),
707+
},
708+
{
709+
Value: "val(v1)",
710+
},
711+
}
712+
713+
// Add optional ef parameter if provided (using "ef: value" format)
714+
if ef := query.ArgValue(schema.SimilarEfArgName); ef != nil {
715+
similarToArgs = append(similarToArgs,
716+
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarEfArgName, ef)},
717+
)
718+
}
719+
720+
// Add optional distance_threshold parameter if provided (using "distance_threshold: value" format)
721+
if dt := query.ArgValue(schema.SimilarDistanceThresholdArgName); dt != nil {
722+
similarToArgs = append(similarToArgs,
723+
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarDistanceThresholdArgName, dt)},
724+
)
725+
}
726+
701727
similarQuery := &dql.GraphQuery{
702728
Attr: "var",
703729
Children: []*dql.GraphQuery{
@@ -712,17 +738,7 @@ func rewriteAsSimilarByIdQuery(
712738
},
713739
Func: &dql.Function{
714740
Name: "similar_to",
715-
Args: []dql.Arg{
716-
{
717-
Value: pred,
718-
},
719-
{
720-
Value: fmt.Sprintf("%v", topK),
721-
},
722-
{
723-
Value: "val(v1)",
724-
},
725-
},
741+
Args: similarToArgs,
726742
},
727743
}
728744

@@ -811,10 +827,10 @@ func rewriteAsSimilarByEmbeddingQuery(
811827
distanceFormula := "math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))" // default = euclidean
812828

813829
if metric == schema.SimilarSearchMetricDotProduct {
814-
distanceFormula = "math(( 1.0 - (($search_vector) dot v2)) /2.0)"
830+
distanceFormula = "math(1.0 - (($search_vector) dot v2))"
815831
} else if metric == schema.SimilarSearchMetricCosine {
816-
distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" +
817-
" * (v2 dot v2) ) )) / 2.0)"
832+
distanceFormula = "math(1.0 - ((($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" +
833+
" * (v2 dot v2) )))"
818834
}
819835

820836
// Save vectorString as a query variable, $search_vector
@@ -834,19 +850,35 @@ func rewriteAsSimilarByEmbeddingQuery(
834850
// Create similar_to as the root function, passing $search_vector as
835851
// the search vector
836852
dgQuery[0].Attr = "var"
853+
similarToArgs := []dql.Arg{
854+
{
855+
Value: pred,
856+
},
857+
{
858+
Value: fmt.Sprintf("%v", topK),
859+
},
860+
{
861+
Value: "$search_vector",
862+
},
863+
}
864+
865+
// Add optional ef parameter if provided (using "ef: value" format)
866+
if ef := query.ArgValue(schema.SimilarEfArgName); ef != nil {
867+
similarToArgs = append(similarToArgs,
868+
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarEfArgName, ef)},
869+
)
870+
}
871+
872+
// Add optional distance_threshold parameter if provided (using "distance_threshold: value" format)
873+
if dt := query.ArgValue(schema.SimilarDistanceThresholdArgName); dt != nil {
874+
similarToArgs = append(similarToArgs,
875+
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarDistanceThresholdArgName, dt)},
876+
)
877+
}
878+
837879
dgQuery[0].Func = &dql.Function{
838880
Name: "similar_to",
839-
Args: []dql.Arg{
840-
{
841-
Value: pred,
842-
},
843-
{
844-
Value: fmt.Sprintf("%v", topK),
845-
},
846-
{
847-
Value: "$search_vector",
848-
},
849-
},
881+
Args: similarToArgs,
850882
}
851883

852884
// Compute the euclidean distance between the neighbor

graphql/resolve/query_test.yaml

Lines changed: 172 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3432,7 +3432,7 @@
34323432
}
34333433
var(func: similar_to(ProjectCosine.description_v, 3, val(v1))) {
34343434
v2 as ProjectCosine.description_v
3435-
distance as math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)
3435+
distance as math(1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) )))
34363436
}
34373437
querySimilarProjectCosineById(func: uid(distance), orderasc: val(distance)) {
34383438
ProjectCosine.id : ProjectCosine.id
@@ -3457,7 +3457,7 @@
34573457
query querySimilarProjectCosineByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
34583458
var(func: similar_to(ProjectCosine.description_v, 1, $search_vector)) @filter(type(ProjectCosine)) {
34593459
v2 as ProjectCosine.description_v
3460-
distance as math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0)
3460+
distance as math(1.0 - ((($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) )))
34613461
}
34623462
querySimilarProjectCosineByEmbedding(func: uid(distance), orderasc: val(distance)) {
34633463
ProjectCosine.id : ProjectCosine.id
@@ -3487,7 +3487,7 @@
34873487
}
34883488
var(func: similar_to(ProjectDotProduct.description_v, 3, val(v1))) {
34893489
v2 as ProjectDotProduct.description_v
3490-
distance as math((1.0 - (v1 dot v2)) /2.0)
3490+
distance as math(1.0 - (v1 dot v2))
34913491
}
34923492
querySimilarProjectDotProductById(func: uid(distance), orderasc: val(distance)) {
34933493
ProjectDotProduct.id : ProjectDotProduct.id
@@ -3512,7 +3512,7 @@
35123512
query querySimilarProjectDotProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
35133513
var(func: similar_to(ProjectDotProduct.description_v, 1, $search_vector)) @filter(type(ProjectDotProduct)) {
35143514
v2 as ProjectDotProduct.description_v
3515-
distance as math(( 1.0 - (($search_vector) dot v2)) /2.0)
3515+
distance as math(1.0 - (($search_vector) dot v2))
35163516
}
35173517
querySimilarProjectDotProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
35183518
ProjectDotProduct.id : ProjectDotProduct.id
@@ -3522,3 +3522,171 @@
35223522
ProjectDotProduct.vector_distance : val(distance)
35233523
}
35243524
}
3525+
3526+
- name: query similar_to with ef parameter
3527+
gqlquery: |
3528+
query {
3529+
querySimilarProductByEmbedding(by: productVector, topK: 5, vector: [0.1, 0.2, 0.3, 0.4, 0.5], ef: 64) {
3530+
id
3531+
title
3532+
productVector
3533+
}
3534+
}
3535+
3536+
dgquery: |-
3537+
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
3538+
var(func: similar_to(Product.productVector, 5, $search_vector, ef: 64)) @filter(type(Product)) {
3539+
v2 as Product.productVector
3540+
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
3541+
}
3542+
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
3543+
Product.id : Product.id
3544+
Product.title : Product.title
3545+
Product.productVector : Product.productVector
3546+
dgraph.uid : uid
3547+
Product.vector_distance : val(distance)
3548+
}
3549+
}
3550+
3551+
- name: query similar_to with distance_threshold parameter
3552+
gqlquery: |
3553+
query {
3554+
querySimilarProductByEmbedding(by: productVector, topK: 10, vector: [0.1, 0.2, 0.3, 0.4, 0.5], distance_threshold: 0.5) {
3555+
id
3556+
title
3557+
productVector
3558+
}
3559+
}
3560+
3561+
dgquery: |-
3562+
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
3563+
var(func: similar_to(Product.productVector, 10, $search_vector, distance_threshold: 0.5)) @filter(type(Product)) {
3564+
v2 as Product.productVector
3565+
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
3566+
}
3567+
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
3568+
Product.id : Product.id
3569+
Product.title : Product.title
3570+
Product.productVector : Product.productVector
3571+
dgraph.uid : uid
3572+
Product.vector_distance : val(distance)
3573+
}
3574+
}
3575+
3576+
- name: query similar_to with both ef and distance_threshold parameters
3577+
gqlquery: |
3578+
query {
3579+
querySimilarProductByEmbedding(by: productVector, topK: 8, vector: [0.1, 0.2, 0.3, 0.4, 0.5], ef: 128, distance_threshold: 0.75) {
3580+
id
3581+
title
3582+
productVector
3583+
}
3584+
}
3585+
3586+
dgquery: |-
3587+
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
3588+
var(func: similar_to(Product.productVector, 8, $search_vector, ef: 128, distance_threshold: 0.75)) @filter(type(Product)) {
3589+
v2 as Product.productVector
3590+
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
3591+
}
3592+
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
3593+
Product.id : Product.id
3594+
Product.title : Product.title
3595+
Product.productVector : Product.productVector
3596+
dgraph.uid : uid
3597+
Product.vector_distance : val(distance)
3598+
}
3599+
}
3600+
3601+
- name: query vector by id with ef parameter
3602+
gqlquery: |
3603+
query {
3604+
querySimilarProductById(by: productVector, topK: 5, id: "0x1", ef: 64) {
3605+
id
3606+
title
3607+
productVector
3608+
}
3609+
}
3610+
3611+
dgquery: |-
3612+
query {
3613+
var(func: eq(Product.id, "0x1")) @filter(type(Product)) {
3614+
vec as Product.productVector
3615+
}
3616+
var() {
3617+
v1 as max(val(vec))
3618+
}
3619+
var(func: similar_to(Product.productVector, 5, val(v1), ef: 64)) {
3620+
v2 as Product.productVector
3621+
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
3622+
}
3623+
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
3624+
Product.id : Product.id
3625+
Product.title : Product.title
3626+
Product.productVector : Product.productVector
3627+
dgraph.uid : uid
3628+
Product.vector_distance : val(distance)
3629+
}
3630+
}
3631+
3632+
- name: query vector by id with distance_threshold parameter
3633+
gqlquery: |
3634+
query {
3635+
querySimilarProductById(by: productVector, topK: 10, id: "0x1", distance_threshold: 0.5) {
3636+
id
3637+
title
3638+
productVector
3639+
}
3640+
}
3641+
3642+
dgquery: |-
3643+
query {
3644+
var(func: eq(Product.id, "0x1")) @filter(type(Product)) {
3645+
vec as Product.productVector
3646+
}
3647+
var() {
3648+
v1 as max(val(vec))
3649+
}
3650+
var(func: similar_to(Product.productVector, 10, val(v1), distance_threshold: 0.5)) {
3651+
v2 as Product.productVector
3652+
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
3653+
}
3654+
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
3655+
Product.id : Product.id
3656+
Product.title : Product.title
3657+
Product.productVector : Product.productVector
3658+
dgraph.uid : uid
3659+
Product.vector_distance : val(distance)
3660+
}
3661+
}
3662+
3663+
- name: query vector by id with both ef and distance_threshold parameters
3664+
gqlquery: |
3665+
query {
3666+
querySimilarProductById(by: productVector, topK: 8, id: "0x1", ef: 128, distance_threshold: 0.75) {
3667+
id
3668+
title
3669+
productVector
3670+
}
3671+
}
3672+
3673+
dgquery: |-
3674+
query {
3675+
var(func: eq(Product.id, "0x1")) @filter(type(Product)) {
3676+
vec as Product.productVector
3677+
}
3678+
var() {
3679+
v1 as max(val(vec))
3680+
}
3681+
var(func: similar_to(Product.productVector, 8, val(v1), ef: 128, distance_threshold: 0.75)) {
3682+
v2 as Product.productVector
3683+
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
3684+
}
3685+
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
3686+
Product.id : Product.id
3687+
Product.title : Product.title
3688+
Product.productVector : Product.productVector
3689+
dgraph.uid : uid
3690+
Product.vector_distance : val(distance)
3691+
}
3692+
}

graphql/schema/gqlschema.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2088,6 +2088,25 @@ func addSimilarByEmbeddingQuery(schema *ast.Schema, defn *ast.Definition) {
20882088
NonNull: true,
20892089
},
20902090
})
2091+
2092+
// Accept optional ef parameter for HNSW search
2093+
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
2094+
Name: SimilarEfArgName,
2095+
Type: &ast.Type{
2096+
NamedType: "Int",
2097+
NonNull: false,
2098+
},
2099+
})
2100+
2101+
// Accept optional distance_threshold parameter for filtering results
2102+
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
2103+
Name: SimilarDistanceThresholdArgName,
2104+
Type: &ast.Type{
2105+
NamedType: "Float",
2106+
NonNull: false,
2107+
},
2108+
})
2109+
20912110
addFilterArgument(schema, qry)
20922111

20932112
schema.Query.Fields = append(schema.Query.Fields, qry)
@@ -2197,6 +2216,24 @@ func addSimilarByIdQuery(schema *ast.Schema, defn *ast.Definition,
21972216
},
21982217
})
21992218

2219+
// Accept optional ef parameter for HNSW search
2220+
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
2221+
Name: SimilarEfArgName,
2222+
Type: &ast.Type{
2223+
NamedType: "Int",
2224+
NonNull: false,
2225+
},
2226+
})
2227+
2228+
// Accept optional distance_threshold parameter for filtering results
2229+
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
2230+
Name: SimilarDistanceThresholdArgName,
2231+
Type: &ast.Type{
2232+
NamedType: "Float",
2233+
NonNull: false,
2234+
},
2235+
})
2236+
22002237
addFilterArgument(schema, qry)
22012238
schema.Query.Fields = append(schema.Query.Fields, qry)
22022239
}

0 commit comments

Comments
 (0)