diff --git a/graphql/resolve/query_rewriter.go b/graphql/resolve/query_rewriter.go index 20e6e511d2a..e5a69e8f9c1 100644 --- a/graphql/resolve/query_rewriter.go +++ b/graphql/resolve/query_rewriter.go @@ -651,9 +651,9 @@ func rewriteAsSimilarByIdQuery( distanceFormula := "math(sqrt((v2 - v1) dot (v2 - v1)))" // default - euclidean if metric == schema.SimilarSearchMetricDotProduct { - distanceFormula = "math((1.0 - (v1 dot v2)) /2.0)" + distanceFormula = "math(1.0 - (v1 dot v2))" } else if metric == schema.SimilarSearchMetricCosine { - distanceFormula = "math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)" + distanceFormula = "math(1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) )))" } // First generate the query to fetch the uid @@ -698,6 +698,32 @@ func rewriteAsSimilarByIdQuery( // v2 as Product.embedding // distance as math((v2 - v1) dot (v2 - v1)) // } + similarToArgs := []dql.Arg{ + { + Value: pred, + }, + { + Value: fmt.Sprintf("%v", topK), + }, + { + Value: "val(v1)", + }, + } + + // Add optional ef parameter if provided (using "ef: value" format) + if ef := query.ArgValue(schema.SimilarEfArgName); ef != nil { + similarToArgs = append(similarToArgs, + dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarEfArgName, ef)}, + ) + } + + // Add optional distance_threshold parameter if provided (using "distance_threshold: value" format) + if dt := query.ArgValue(schema.SimilarDistanceThresholdArgName); dt != nil { + similarToArgs = append(similarToArgs, + dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarDistanceThresholdArgName, dt)}, + ) + } + similarQuery := &dql.GraphQuery{ Attr: "var", Children: []*dql.GraphQuery{ @@ -712,17 +738,7 @@ func rewriteAsSimilarByIdQuery( }, Func: &dql.Function{ Name: "similar_to", - Args: []dql.Arg{ - { - Value: pred, - }, - { - Value: fmt.Sprintf("%v", topK), - }, - { - Value: "val(v1)", - }, - }, + Args: similarToArgs, }, } @@ -811,10 +827,10 @@ func rewriteAsSimilarByEmbeddingQuery( distanceFormula := "math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))" // default = euclidean if metric == schema.SimilarSearchMetricDotProduct { - distanceFormula = "math(( 1.0 - (($search_vector) dot v2)) /2.0)" + distanceFormula = "math(1.0 - (($search_vector) dot v2))" } else if metric == schema.SimilarSearchMetricCosine { - distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" + - " * (v2 dot v2) ) )) / 2.0)" + distanceFormula = "math(1.0 - ((($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" + + " * (v2 dot v2) )))" } // Save vectorString as a query variable, $search_vector @@ -834,19 +850,35 @@ func rewriteAsSimilarByEmbeddingQuery( // Create similar_to as the root function, passing $search_vector as // the search vector dgQuery[0].Attr = "var" + similarToArgs := []dql.Arg{ + { + Value: pred, + }, + { + Value: fmt.Sprintf("%v", topK), + }, + { + Value: "$search_vector", + }, + } + + // Add optional ef parameter if provided (using "ef: value" format) + if ef := query.ArgValue(schema.SimilarEfArgName); ef != nil { + similarToArgs = append(similarToArgs, + dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarEfArgName, ef)}, + ) + } + + // Add optional distance_threshold parameter if provided (using "distance_threshold: value" format) + if dt := query.ArgValue(schema.SimilarDistanceThresholdArgName); dt != nil { + similarToArgs = append(similarToArgs, + dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarDistanceThresholdArgName, dt)}, + ) + } + dgQuery[0].Func = &dql.Function{ Name: "similar_to", - Args: []dql.Arg{ - { - Value: pred, - }, - { - Value: fmt.Sprintf("%v", topK), - }, - { - Value: "$search_vector", - }, - }, + Args: similarToArgs, } // Compute the euclidean distance between the neighbor diff --git a/graphql/resolve/query_test.yaml b/graphql/resolve/query_test.yaml index a9bae8e59b8..3465384e575 100644 --- a/graphql/resolve/query_test.yaml +++ b/graphql/resolve/query_test.yaml @@ -3432,7 +3432,7 @@ } var(func: similar_to(ProjectCosine.description_v, 3, val(v1))) { v2 as ProjectCosine.description_v - distance as math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0) + distance as math(1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ))) } querySimilarProjectCosineById(func: uid(distance), orderasc: val(distance)) { ProjectCosine.id : ProjectCosine.id @@ -3457,7 +3457,7 @@ query querySimilarProjectCosineByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { var(func: similar_to(ProjectCosine.description_v, 1, $search_vector)) @filter(type(ProjectCosine)) { v2 as ProjectCosine.description_v - distance as math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0) + distance as math(1.0 - ((($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ))) } querySimilarProjectCosineByEmbedding(func: uid(distance), orderasc: val(distance)) { ProjectCosine.id : ProjectCosine.id @@ -3487,7 +3487,7 @@ } var(func: similar_to(ProjectDotProduct.description_v, 3, val(v1))) { v2 as ProjectDotProduct.description_v - distance as math((1.0 - (v1 dot v2)) /2.0) + distance as math(1.0 - (v1 dot v2)) } querySimilarProjectDotProductById(func: uid(distance), orderasc: val(distance)) { ProjectDotProduct.id : ProjectDotProduct.id @@ -3512,7 +3512,7 @@ query querySimilarProjectDotProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { var(func: similar_to(ProjectDotProduct.description_v, 1, $search_vector)) @filter(type(ProjectDotProduct)) { v2 as ProjectDotProduct.description_v - distance as math(( 1.0 - (($search_vector) dot v2)) /2.0) + distance as math(1.0 - (($search_vector) dot v2)) } querySimilarProjectDotProductByEmbedding(func: uid(distance), orderasc: val(distance)) { ProjectDotProduct.id : ProjectDotProduct.id @@ -3522,3 +3522,171 @@ ProjectDotProduct.vector_distance : val(distance) } } + +- name: query similar_to with ef parameter + gqlquery: | + query { + querySimilarProductByEmbedding(by: productVector, topK: 5, vector: [0.1, 0.2, 0.3, 0.4, 0.5], ef: 64) { + id + title + productVector + } + } + + dgquery: |- + query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { + var(func: similar_to(Product.productVector, 5, $search_vector, ef: 64)) @filter(type(Product)) { + v2 as Product.productVector + distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector))) + } + querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) { + Product.id : Product.id + Product.title : Product.title + Product.productVector : Product.productVector + dgraph.uid : uid + Product.vector_distance : val(distance) + } + } + +- name: query similar_to with distance_threshold parameter + gqlquery: | + query { + querySimilarProductByEmbedding(by: productVector, topK: 10, vector: [0.1, 0.2, 0.3, 0.4, 0.5], distance_threshold: 0.5) { + id + title + productVector + } + } + + dgquery: |- + query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { + var(func: similar_to(Product.productVector, 10, $search_vector, distance_threshold: 0.5)) @filter(type(Product)) { + v2 as Product.productVector + distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector))) + } + querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) { + Product.id : Product.id + Product.title : Product.title + Product.productVector : Product.productVector + dgraph.uid : uid + Product.vector_distance : val(distance) + } + } + +- name: query similar_to with both ef and distance_threshold parameters + gqlquery: | + query { + querySimilarProductByEmbedding(by: productVector, topK: 8, vector: [0.1, 0.2, 0.3, 0.4, 0.5], ef: 128, distance_threshold: 0.75) { + id + title + productVector + } + } + + dgquery: |- + query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { + var(func: similar_to(Product.productVector, 8, $search_vector, ef: 128, distance_threshold: 0.75)) @filter(type(Product)) { + v2 as Product.productVector + distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector))) + } + querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) { + Product.id : Product.id + Product.title : Product.title + Product.productVector : Product.productVector + dgraph.uid : uid + Product.vector_distance : val(distance) + } + } + +- name: query vector by id with ef parameter + gqlquery: | + query { + querySimilarProductById(by: productVector, topK: 5, id: "0x1", ef: 64) { + id + title + productVector + } + } + + dgquery: |- + query { + var(func: eq(Product.id, "0x1")) @filter(type(Product)) { + vec as Product.productVector + } + var() { + v1 as max(val(vec)) + } + var(func: similar_to(Product.productVector, 5, val(v1), ef: 64)) { + v2 as Product.productVector + distance as math(sqrt((v2 - v1) dot (v2 - v1))) + } + querySimilarProductById(func: uid(distance), orderasc: val(distance)) { + Product.id : Product.id + Product.title : Product.title + Product.productVector : Product.productVector + dgraph.uid : uid + Product.vector_distance : val(distance) + } + } + +- name: query vector by id with distance_threshold parameter + gqlquery: | + query { + querySimilarProductById(by: productVector, topK: 10, id: "0x1", distance_threshold: 0.5) { + id + title + productVector + } + } + + dgquery: |- + query { + var(func: eq(Product.id, "0x1")) @filter(type(Product)) { + vec as Product.productVector + } + var() { + v1 as max(val(vec)) + } + var(func: similar_to(Product.productVector, 10, val(v1), distance_threshold: 0.5)) { + v2 as Product.productVector + distance as math(sqrt((v2 - v1) dot (v2 - v1))) + } + querySimilarProductById(func: uid(distance), orderasc: val(distance)) { + Product.id : Product.id + Product.title : Product.title + Product.productVector : Product.productVector + dgraph.uid : uid + Product.vector_distance : val(distance) + } + } + +- name: query vector by id with both ef and distance_threshold parameters + gqlquery: | + query { + querySimilarProductById(by: productVector, topK: 8, id: "0x1", ef: 128, distance_threshold: 0.75) { + id + title + productVector + } + } + + dgquery: |- + query { + var(func: eq(Product.id, "0x1")) @filter(type(Product)) { + vec as Product.productVector + } + var() { + v1 as max(val(vec)) + } + var(func: similar_to(Product.productVector, 8, val(v1), ef: 128, distance_threshold: 0.75)) { + v2 as Product.productVector + distance as math(sqrt((v2 - v1) dot (v2 - v1))) + } + querySimilarProductById(func: uid(distance), orderasc: val(distance)) { + Product.id : Product.id + Product.title : Product.title + Product.productVector : Product.productVector + dgraph.uid : uid + Product.vector_distance : val(distance) + } + } diff --git a/graphql/schema/gqlschema.go b/graphql/schema/gqlschema.go index ae4cf684295..32eda06f88f 100644 --- a/graphql/schema/gqlschema.go +++ b/graphql/schema/gqlschema.go @@ -2088,6 +2088,25 @@ func addSimilarByEmbeddingQuery(schema *ast.Schema, defn *ast.Definition) { NonNull: true, }, }) + + // Accept optional ef parameter for HNSW search + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: SimilarEfArgName, + Type: &ast.Type{ + NamedType: "Int", + NonNull: false, + }, + }) + + // Accept optional distance_threshold parameter for filtering results + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: SimilarDistanceThresholdArgName, + Type: &ast.Type{ + NamedType: "Float", + NonNull: false, + }, + }) + addFilterArgument(schema, qry) schema.Query.Fields = append(schema.Query.Fields, qry) @@ -2197,6 +2216,24 @@ func addSimilarByIdQuery(schema *ast.Schema, defn *ast.Definition, }, }) + // Accept optional ef parameter for HNSW search + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: SimilarEfArgName, + Type: &ast.Type{ + NamedType: "Int", + NonNull: false, + }, + }) + + // Accept optional distance_threshold parameter for filtering results + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: SimilarDistanceThresholdArgName, + Type: &ast.Type{ + NamedType: "Float", + NonNull: false, + }, + }) + addFilterArgument(schema, qry) schema.Query.Fields = append(schema.Query.Fields, qry) } diff --git a/graphql/schema/testdata/schemagen/output/embedding-directive-with-similar-queries.graphql b/graphql/schema/testdata/schemagen/output/embedding-directive-with-similar-queries.graphql index f7b1699c512..351d2eb769b 100644 --- a/graphql/schema/testdata/schemagen/output/embedding-directive-with-similar-queries.graphql +++ b/graphql/schema/testdata/schemagen/output/embedding-directive-with-similar-queries.graphql @@ -553,15 +553,15 @@ input UserRef { type Query { getProduct(id: String!): Product - querySimilarProductById(id: String!, by: ProductEmbedding!, topK: Int!, filter: ProductFilter): [Product] - querySimilarProductByEmbedding(by: ProductEmbedding!, topK: Int!, vector: [Float!]!, filter: ProductFilter): [Product] + querySimilarProductById(id: String!, by: ProductEmbedding!, topK: Int!, ef: Int, distance_threshold: Float, filter: ProductFilter): [Product] + querySimilarProductByEmbedding(by: ProductEmbedding!, topK: Int!, vector: [Float!]!, ef: Int, distance_threshold: Float, filter: ProductFilter): [Product] queryProduct(filter: ProductFilter, order: ProductOrder, first: Int, offset: Int): [Product] aggregateProduct(filter: ProductFilter): ProductAggregateResult queryPurchase(filter: PurchaseFilter, order: PurchaseOrder, first: Int, offset: Int): [Purchase] aggregatePurchase(filter: PurchaseFilter): PurchaseAggregateResult getUser(email: String!): User - querySimilarUserById(email: String!, by: UserEmbedding!, topK: Int!, filter: UserFilter): [User] - querySimilarUserByEmbedding(by: UserEmbedding!, topK: Int!, vector: [Float!]!, filter: UserFilter): [User] + querySimilarUserById(email: String!, by: UserEmbedding!, topK: Int!, ef: Int, distance_threshold: Float, filter: UserFilter): [User] + querySimilarUserByEmbedding(by: UserEmbedding!, topK: Int!, vector: [Float!]!, ef: Int, distance_threshold: Float, filter: UserFilter): [User] queryUser(filter: UserFilter, order: UserOrder, first: Int, offset: Int): [User] aggregateUser(filter: UserFilter): UserAggregateResult } diff --git a/graphql/schema/wrappers.go b/graphql/schema/wrappers.go index 8837d5170d8..f30df66a50d 100644 --- a/graphql/schema/wrappers.go +++ b/graphql/schema/wrappers.go @@ -76,38 +76,40 @@ type EntityRepresentations struct { // Query/Mutation types and arg names const ( - GetQuery QueryType = "get" - SimilarByIdQuery QueryType = "querySimilarById" - SimilarByEmbeddingQuery QueryType = "querySimilarByEmbedding" - FilterQuery QueryType = "query" - AggregateQuery QueryType = "aggregate" - SchemaQuery QueryType = "schema" - EntitiesQuery QueryType = "entities" - PasswordQuery QueryType = "checkPassword" - HTTPQuery QueryType = "http" - DQLQuery QueryType = "dql" - NotSupportedQuery QueryType = "notsupported" - AddMutation MutationType = "add" - UpdateMutation MutationType = "update" - DeleteMutation MutationType = "delete" - HTTPMutation MutationType = "http" - NotSupportedMutation MutationType = "notsupported" - IDType = "ID" - InputArgName = "input" - UpsertArgName = "upsert" - FilterArgName = "filter" - SimilarByArgName = "by" - SimilarTopKArgName = "topK" - SimilarVectorArgName = "vector" - EmbeddingEnumSuffix = "Embedding" - SimilarQueryPrefix = "querySimilar" - SimilarByIdQuerySuffix = "ById" - SimilarByEmbeddingQuerySuffix = "ByEmbedding" - SimilarQueryResultTypeSuffix = "WithDistance" - SimilarQueryDistanceFieldName = "vector_distance" - SimilarSearchMetricEuclidean = "euclidean" - SimilarSearchMetricDotProduct = "dotproduct" - SimilarSearchMetricCosine = "cosine" + GetQuery QueryType = "get" + SimilarByIdQuery QueryType = "querySimilarById" + SimilarByEmbeddingQuery QueryType = "querySimilarByEmbedding" + FilterQuery QueryType = "query" + AggregateQuery QueryType = "aggregate" + SchemaQuery QueryType = "schema" + EntitiesQuery QueryType = "entities" + PasswordQuery QueryType = "checkPassword" + HTTPQuery QueryType = "http" + DQLQuery QueryType = "dql" + NotSupportedQuery QueryType = "notsupported" + AddMutation MutationType = "add" + UpdateMutation MutationType = "update" + DeleteMutation MutationType = "delete" + HTTPMutation MutationType = "http" + NotSupportedMutation MutationType = "notsupported" + IDType = "ID" + InputArgName = "input" + UpsertArgName = "upsert" + FilterArgName = "filter" + SimilarByArgName = "by" + SimilarTopKArgName = "topK" + SimilarVectorArgName = "vector" + SimilarEfArgName = "ef" + SimilarDistanceThresholdArgName = "distance_threshold" + EmbeddingEnumSuffix = "Embedding" + SimilarQueryPrefix = "querySimilar" + SimilarByIdQuerySuffix = "ById" + SimilarByEmbeddingQuerySuffix = "ByEmbedding" + SimilarQueryResultTypeSuffix = "WithDistance" + SimilarQueryDistanceFieldName = "vector_distance" + SimilarSearchMetricEuclidean = "euclidean" + SimilarSearchMetricDotProduct = "dotproduct" + SimilarSearchMetricCosine = "cosine" ) // Schema represents a valid GraphQL schema @@ -1388,7 +1390,9 @@ func (f *field) IDArgValue() (xids map[string]string, uid uint64, err error) { if (idField == nil || arg.Name != idField.Name()) && (passwordField == nil || arg.Name != passwordField.Name()) && (queryType(f.field.Name, nil) != SimilarByIdQuery || - (arg.Name != SimilarTopKArgName && arg.Name != SimilarByArgName && arg.Name != "filter")) { + (arg.Name != SimilarTopKArgName && arg.Name != SimilarByArgName && + arg.Name != SimilarEfArgName && arg.Name != SimilarDistanceThresholdArgName && + arg.Name != "filter")) { xidArgName = arg.Name }