Skip to content

Commit d61c91a

Browse files
Add support for ef and distance_threshold in generated GraphQL queries
1 parent 4e881c0 commit d61c91a

5 files changed

Lines changed: 309 additions & 68 deletions

File tree

graphql/resolve/query_rewriter.go

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -651,9 +651,9 @@ func rewriteAsSimilarByIdQuery(
651651
distanceFormula := "math(sqrt((v2 - v1) dot (v2 - v1)))" // default - euclidean
652652

653653
if metric == schema.SimilarSearchMetricDotProduct {
654-
distanceFormula = "math((1.0 - (v1 dot v2)) /2.0)"
654+
distanceFormula = "math(1.0 - (v1 dot v2))"
655655
} else if metric == schema.SimilarSearchMetricCosine {
656-
distanceFormula = "math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)"
656+
distanceFormula = "math(1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) )))"
657657
}
658658

659659
// First generate the query to fetch the uid
@@ -698,6 +698,32 @@ func rewriteAsSimilarByIdQuery(
698698
// v2 as Product.embedding
699699
// distance as math((v2 - v1) dot (v2 - v1))
700700
// }
701+
similarToArgs := []dql.Arg{
702+
{
703+
Value: pred,
704+
},
705+
{
706+
Value: fmt.Sprintf("%v", topK),
707+
},
708+
{
709+
Value: "val(v1)",
710+
},
711+
}
712+
713+
// Add optional ef parameter if provided (using "ef: value" format)
714+
if ef := query.ArgValue(schema.SimilarEfArgName); ef != nil {
715+
similarToArgs = append(similarToArgs,
716+
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarEfArgName, ef)},
717+
)
718+
}
719+
720+
// Add optional distance_threshold parameter if provided (using "distance_threshold: value" format)
721+
if dt := query.ArgValue(schema.SimilarDistanceThresholdArgName); dt != nil {
722+
similarToArgs = append(similarToArgs,
723+
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarDistanceThresholdArgName, dt)},
724+
)
725+
}
726+
701727
similarQuery := &dql.GraphQuery{
702728
Attr: "var",
703729
Children: []*dql.GraphQuery{
@@ -712,17 +738,7 @@ func rewriteAsSimilarByIdQuery(
712738
},
713739
Func: &dql.Function{
714740
Name: "similar_to",
715-
Args: []dql.Arg{
716-
{
717-
Value: pred,
718-
},
719-
{
720-
Value: fmt.Sprintf("%v", topK),
721-
},
722-
{
723-
Value: "val(v1)",
724-
},
725-
},
741+
Args: similarToArgs,
726742
},
727743
}
728744

@@ -811,10 +827,10 @@ func rewriteAsSimilarByEmbeddingQuery(
811827
distanceFormula := "math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))" // default = euclidean
812828

813829
if metric == schema.SimilarSearchMetricDotProduct {
814-
distanceFormula = "math(( 1.0 - (($search_vector) dot v2)) /2.0)"
830+
distanceFormula = "math(1.0 - (($search_vector) dot v2))"
815831
} else if metric == schema.SimilarSearchMetricCosine {
816-
distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" +
817-
" * (v2 dot v2) ) )) / 2.0)"
832+
distanceFormula = "math(1.0 - ((($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" +
833+
" * (v2 dot v2) )))"
818834
}
819835

820836
// Save vectorString as a query variable, $search_vector
@@ -834,19 +850,35 @@ func rewriteAsSimilarByEmbeddingQuery(
834850
// Create similar_to as the root function, passing $search_vector as
835851
// the search vector
836852
dgQuery[0].Attr = "var"
853+
similarToArgs := []dql.Arg{
854+
{
855+
Value: pred,
856+
},
857+
{
858+
Value: fmt.Sprintf("%v", topK),
859+
},
860+
{
861+
Value: "$search_vector",
862+
},
863+
}
864+
865+
// Add optional ef parameter if provided (using "ef: value" format)
866+
if ef := query.ArgValue(schema.SimilarEfArgName); ef != nil {
867+
similarToArgs = append(similarToArgs,
868+
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarEfArgName, ef)},
869+
)
870+
}
871+
872+
// Add optional distance_threshold parameter if provided (using "distance_threshold: value" format)
873+
if dt := query.ArgValue(schema.SimilarDistanceThresholdArgName); dt != nil {
874+
similarToArgs = append(similarToArgs,
875+
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarDistanceThresholdArgName, dt)},
876+
)
877+
}
878+
837879
dgQuery[0].Func = &dql.Function{
838880
Name: "similar_to",
839-
Args: []dql.Arg{
840-
{
841-
Value: pred,
842-
},
843-
{
844-
Value: fmt.Sprintf("%v", topK),
845-
},
846-
{
847-
Value: "$search_vector",
848-
},
849-
},
881+
Args: similarToArgs,
850882
}
851883

852884
// Compute the euclidean distance between the neighbor

graphql/resolve/query_test.yaml

Lines changed: 172 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3432,7 +3432,7 @@
34323432
}
34333433
var(func: similar_to(ProjectCosine.description_v, 3, val(v1))) {
34343434
v2 as ProjectCosine.description_v
3435-
distance as math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)
3435+
distance as math(1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) )))
34363436
}
34373437
querySimilarProjectCosineById(func: uid(distance), orderasc: val(distance)) {
34383438
ProjectCosine.id : ProjectCosine.id
@@ -3457,7 +3457,7 @@
34573457
query querySimilarProjectCosineByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
34583458
var(func: similar_to(ProjectCosine.description_v, 1, $search_vector)) @filter(type(ProjectCosine)) {
34593459
v2 as ProjectCosine.description_v
3460-
distance as math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0)
3460+
distance as math(1.0 - ((($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) )))
34613461
}
34623462
querySimilarProjectCosineByEmbedding(func: uid(distance), orderasc: val(distance)) {
34633463
ProjectCosine.id : ProjectCosine.id
@@ -3487,7 +3487,7 @@
34873487
}
34883488
var(func: similar_to(ProjectDotProduct.description_v, 3, val(v1))) {
34893489
v2 as ProjectDotProduct.description_v
3490-
distance as math((1.0 - (v1 dot v2)) /2.0)
3490+
distance as math(1.0 - (v1 dot v2))
34913491
}
34923492
querySimilarProjectDotProductById(func: uid(distance), orderasc: val(distance)) {
34933493
ProjectDotProduct.id : ProjectDotProduct.id
@@ -3512,7 +3512,7 @@
35123512
query querySimilarProjectDotProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
35133513
var(func: similar_to(ProjectDotProduct.description_v, 1, $search_vector)) @filter(type(ProjectDotProduct)) {
35143514
v2 as ProjectDotProduct.description_v
3515-
distance as math(( 1.0 - (($search_vector) dot v2)) /2.0)
3515+
distance as math(1.0 - (($search_vector) dot v2))
35163516
}
35173517
querySimilarProjectDotProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
35183518
ProjectDotProduct.id : ProjectDotProduct.id
@@ -3522,3 +3522,171 @@
35223522
ProjectDotProduct.vector_distance : val(distance)
35233523
}
35243524
}
3525+
3526+
- name: query similar_to with ef parameter
3527+
gqlquery: |
3528+
query {
3529+
querySimilarProductByEmbedding(by: productVector, topK: 5, vector: [0.1, 0.2, 0.3, 0.4, 0.5], ef: 64) {
3530+
id
3531+
title
3532+
productVector
3533+
}
3534+
}
3535+
3536+
dgquery: |-
3537+
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
3538+
var(func: similar_to(Product.productVector, 5, $search_vector, ef: 64)) @filter(type(Product)) {
3539+
v2 as Product.productVector
3540+
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
3541+
}
3542+
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
3543+
Product.id : Product.id
3544+
Product.title : Product.title
3545+
Product.productVector : Product.productVector
3546+
dgraph.uid : uid
3547+
Product.vector_distance : val(distance)
3548+
}
3549+
}
3550+
3551+
- name: query similar_to with distance_threshold parameter
3552+
gqlquery: |
3553+
query {
3554+
querySimilarProductByEmbedding(by: productVector, topK: 10, vector: [0.1, 0.2, 0.3, 0.4, 0.5], distance_threshold: 0.5) {
3555+
id
3556+
title
3557+
productVector
3558+
}
3559+
}
3560+
3561+
dgquery: |-
3562+
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
3563+
var(func: similar_to(Product.productVector, 10, $search_vector, distance_threshold: 0.5)) @filter(type(Product)) {
3564+
v2 as Product.productVector
3565+
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
3566+
}
3567+
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
3568+
Product.id : Product.id
3569+
Product.title : Product.title
3570+
Product.productVector : Product.productVector
3571+
dgraph.uid : uid
3572+
Product.vector_distance : val(distance)
3573+
}
3574+
}
3575+
3576+
- name: query similar_to with both ef and distance_threshold parameters
3577+
gqlquery: |
3578+
query {
3579+
querySimilarProductByEmbedding(by: productVector, topK: 8, vector: [0.1, 0.2, 0.3, 0.4, 0.5], ef: 128, distance_threshold: 0.75) {
3580+
id
3581+
title
3582+
productVector
3583+
}
3584+
}
3585+
3586+
dgquery: |-
3587+
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
3588+
var(func: similar_to(Product.productVector, 8, $search_vector, ef: 128, distance_threshold: 0.75)) @filter(type(Product)) {
3589+
v2 as Product.productVector
3590+
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
3591+
}
3592+
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
3593+
Product.id : Product.id
3594+
Product.title : Product.title
3595+
Product.productVector : Product.productVector
3596+
dgraph.uid : uid
3597+
Product.vector_distance : val(distance)
3598+
}
3599+
}
3600+
3601+
- name: query vector by id with ef parameter
3602+
gqlquery: |
3603+
query {
3604+
querySimilarProductById(by: productVector, topK: 5, id: "0x1", ef: 64) {
3605+
id
3606+
title
3607+
productVector
3608+
}
3609+
}
3610+
3611+
dgquery: |-
3612+
query {
3613+
var(func: eq(Product.id, "0x1")) @filter(type(Product)) {
3614+
vec as Product.productVector
3615+
}
3616+
var() {
3617+
v1 as max(val(vec))
3618+
}
3619+
var(func: similar_to(Product.productVector, 5, val(v1), ef: 64)) {
3620+
v2 as Product.productVector
3621+
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
3622+
}
3623+
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
3624+
Product.id : Product.id
3625+
Product.title : Product.title
3626+
Product.productVector : Product.productVector
3627+
dgraph.uid : uid
3628+
Product.vector_distance : val(distance)
3629+
}
3630+
}
3631+
3632+
- name: query vector by id with distance_threshold parameter
3633+
gqlquery: |
3634+
query {
3635+
querySimilarProductById(by: productVector, topK: 10, id: "0x1", distance_threshold: 0.5) {
3636+
id
3637+
title
3638+
productVector
3639+
}
3640+
}
3641+
3642+
dgquery: |-
3643+
query {
3644+
var(func: eq(Product.id, "0x1")) @filter(type(Product)) {
3645+
vec as Product.productVector
3646+
}
3647+
var() {
3648+
v1 as max(val(vec))
3649+
}
3650+
var(func: similar_to(Product.productVector, 10, val(v1), distance_threshold: 0.5)) {
3651+
v2 as Product.productVector
3652+
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
3653+
}
3654+
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
3655+
Product.id : Product.id
3656+
Product.title : Product.title
3657+
Product.productVector : Product.productVector
3658+
dgraph.uid : uid
3659+
Product.vector_distance : val(distance)
3660+
}
3661+
}
3662+
3663+
- name: query vector by id with both ef and distance_threshold parameters
3664+
gqlquery: |
3665+
query {
3666+
querySimilarProductById(by: productVector, topK: 8, id: "0x1", ef: 128, distance_threshold: 0.75) {
3667+
id
3668+
title
3669+
productVector
3670+
}
3671+
}
3672+
3673+
dgquery: |-
3674+
query {
3675+
var(func: eq(Product.id, "0x1")) @filter(type(Product)) {
3676+
vec as Product.productVector
3677+
}
3678+
var() {
3679+
v1 as max(val(vec))
3680+
}
3681+
var(func: similar_to(Product.productVector, 8, val(v1), ef: 128, distance_threshold: 0.75)) {
3682+
v2 as Product.productVector
3683+
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
3684+
}
3685+
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
3686+
Product.id : Product.id
3687+
Product.title : Product.title
3688+
Product.productVector : Product.productVector
3689+
dgraph.uid : uid
3690+
Product.vector_distance : val(distance)
3691+
}
3692+
}

graphql/schema/gqlschema.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2088,6 +2088,25 @@ func addSimilarByEmbeddingQuery(schema *ast.Schema, defn *ast.Definition) {
20882088
NonNull: true,
20892089
},
20902090
})
2091+
2092+
// Accept optional ef parameter for HNSW search
2093+
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
2094+
Name: SimilarEfArgName,
2095+
Type: &ast.Type{
2096+
NamedType: "Int",
2097+
NonNull: false,
2098+
},
2099+
})
2100+
2101+
// Accept optional distance_threshold parameter for filtering results
2102+
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
2103+
Name: SimilarDistanceThresholdArgName,
2104+
Type: &ast.Type{
2105+
NamedType: "Float",
2106+
NonNull: false,
2107+
},
2108+
})
2109+
20912110
addFilterArgument(schema, qry)
20922111

20932112
schema.Query.Fields = append(schema.Query.Fields, qry)
@@ -2197,6 +2216,24 @@ func addSimilarByIdQuery(schema *ast.Schema, defn *ast.Definition,
21972216
},
21982217
})
21992218

2219+
// Accept optional ef parameter for HNSW search
2220+
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
2221+
Name: SimilarEfArgName,
2222+
Type: &ast.Type{
2223+
NamedType: "Int",
2224+
NonNull: false,
2225+
},
2226+
})
2227+
2228+
// Accept optional distance_threshold parameter for filtering results
2229+
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
2230+
Name: SimilarDistanceThresholdArgName,
2231+
Type: &ast.Type{
2232+
NamedType: "Float",
2233+
NonNull: false,
2234+
},
2235+
})
2236+
22002237
addFilterArgument(schema, qry)
22012238
schema.Query.Fields = append(schema.Query.Fields, qry)
22022239
}

0 commit comments

Comments
 (0)