Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 59 additions & 27 deletions graphql/resolve/query_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -651,9 +651,9 @@ func rewriteAsSimilarByIdQuery(
distanceFormula := "math(sqrt((v2 - v1) dot (v2 - v1)))" // default - euclidean

if metric == schema.SimilarSearchMetricDotProduct {
distanceFormula = "math((1.0 - (v1 dot v2)) /2.0)"
distanceFormula = "math(1.0 - (v1 dot v2))"
} else if metric == schema.SimilarSearchMetricCosine {
distanceFormula = "math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)"
distanceFormula = "math(1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) )))"
}

// First generate the query to fetch the uid
Expand Down Expand Up @@ -698,6 +698,32 @@ func rewriteAsSimilarByIdQuery(
// v2 as Product.embedding
// distance as math(sqrt((v2 - v1) dot (v2 - v1)))
// }
similarToArgs := []dql.Arg{
{
Value: pred,
},
{
Value: fmt.Sprintf("%v", topK),
},
{
Value: "val(v1)",
},
}

// Add optional ef parameter if provided (using "ef: value" format)
if ef := query.ArgValue(schema.SimilarEfArgName); ef != nil {
similarToArgs = append(similarToArgs,
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarEfArgName, ef)},
)
}

// Add optional distance_threshold parameter if provided (using "distance_threshold: value" format)
if dt := query.ArgValue(schema.SimilarDistanceThresholdArgName); dt != nil {
similarToArgs = append(similarToArgs,
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarDistanceThresholdArgName, dt)},
)
}

similarQuery := &dql.GraphQuery{
Attr: "var",
Children: []*dql.GraphQuery{
Expand All @@ -712,17 +738,7 @@ func rewriteAsSimilarByIdQuery(
},
Func: &dql.Function{
Name: "similar_to",
Args: []dql.Arg{
{
Value: pred,
},
{
Value: fmt.Sprintf("%v", topK),
},
{
Value: "val(v1)",
},
},
Args: similarToArgs,
},
}

Expand Down Expand Up @@ -811,10 +827,10 @@ func rewriteAsSimilarByEmbeddingQuery(
distanceFormula := "math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))" // default = euclidean

if metric == schema.SimilarSearchMetricDotProduct {
distanceFormula = "math(( 1.0 - (($search_vector) dot v2)) /2.0)"
distanceFormula = "math(1.0 - (($search_vector) dot v2))"
Comment thread
matthewmcneely marked this conversation as resolved.
} else if metric == schema.SimilarSearchMetricCosine {
distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" +
" * (v2 dot v2) ) )) / 2.0)"
distanceFormula = "math(1.0 - ((($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" +
" * (v2 dot v2) )))"
}

// Save vectorString as a query variable, $search_vector
Expand All @@ -834,19 +850,35 @@ func rewriteAsSimilarByEmbeddingQuery(
// Create similar_to as the root function, passing $search_vector as
// the search vector
dgQuery[0].Attr = "var"
similarToArgs := []dql.Arg{
{
Value: pred,
},
{
Value: fmt.Sprintf("%v", topK),
},
{
Value: "$search_vector",
},
}

// Add optional ef parameter if provided (using "ef: value" format)
if ef := query.ArgValue(schema.SimilarEfArgName); ef != nil {
similarToArgs = append(similarToArgs,
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarEfArgName, ef)},
)
}

// Add optional distance_threshold parameter if provided (using "distance_threshold: value" format)
if dt := query.ArgValue(schema.SimilarDistanceThresholdArgName); dt != nil {
similarToArgs = append(similarToArgs,
dql.Arg{Value: fmt.Sprintf("%s: %v", schema.SimilarDistanceThresholdArgName, dt)},
)
}

dgQuery[0].Func = &dql.Function{
Name: "similar_to",
Args: []dql.Arg{
{
Value: pred,
},
{
Value: fmt.Sprintf("%v", topK),
},
{
Value: "$search_vector",
},
},
Args: similarToArgs,
}

// Compute the euclidean distance between the neighbor
Expand Down
176 changes: 172 additions & 4 deletions graphql/resolve/query_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3432,7 +3432,7 @@
}
var(func: similar_to(ProjectCosine.description_v, 3, val(v1))) {
v2 as ProjectCosine.description_v
distance as math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)
distance as math(1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) )))
}
querySimilarProjectCosineById(func: uid(distance), orderasc: val(distance)) {
ProjectCosine.id : ProjectCosine.id
Expand All @@ -3457,7 +3457,7 @@
query querySimilarProjectCosineByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(ProjectCosine.description_v, 1, $search_vector)) @filter(type(ProjectCosine)) {
v2 as ProjectCosine.description_v
distance as math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0)
distance as math(1.0 - ((($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) )))
}
querySimilarProjectCosineByEmbedding(func: uid(distance), orderasc: val(distance)) {
ProjectCosine.id : ProjectCosine.id
Expand Down Expand Up @@ -3487,7 +3487,7 @@
}
var(func: similar_to(ProjectDotProduct.description_v, 3, val(v1))) {
v2 as ProjectDotProduct.description_v
distance as math((1.0 - (v1 dot v2)) /2.0)
distance as math(1.0 - (v1 dot v2))
}
querySimilarProjectDotProductById(func: uid(distance), orderasc: val(distance)) {
ProjectDotProduct.id : ProjectDotProduct.id
Expand All @@ -3512,7 +3512,7 @@
query querySimilarProjectDotProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(ProjectDotProduct.description_v, 1, $search_vector)) @filter(type(ProjectDotProduct)) {
v2 as ProjectDotProduct.description_v
distance as math(( 1.0 - (($search_vector) dot v2)) /2.0)
distance as math(1.0 - (($search_vector) dot v2))
}
querySimilarProjectDotProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
ProjectDotProduct.id : ProjectDotProduct.id
Expand All @@ -3522,3 +3522,171 @@
ProjectDotProduct.vector_distance : val(distance)
}
}

- name: query similar_to with ef parameter
gqlquery: |
query {
querySimilarProductByEmbedding(by: productVector, topK: 5, vector: [0.1, 0.2, 0.3, 0.4, 0.5], ef: 64) {
id
title
productVector
}
}

dgquery: |-
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(Product.productVector, 5, $search_vector, ef: 64)) @filter(type(Product)) {
v2 as Product.productVector
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
}
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Product.title : Product.title
Product.productVector : Product.productVector
dgraph.uid : uid
Product.vector_distance : val(distance)
}
}

- name: query similar_to with distance_threshold parameter
gqlquery: |
query {
querySimilarProductByEmbedding(by: productVector, topK: 10, vector: [0.1, 0.2, 0.3, 0.4, 0.5], distance_threshold: 0.5) {
id
title
productVector
}
}

dgquery: |-
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(Product.productVector, 10, $search_vector, distance_threshold: 0.5)) @filter(type(Product)) {
v2 as Product.productVector
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
}
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Product.title : Product.title
Product.productVector : Product.productVector
dgraph.uid : uid
Product.vector_distance : val(distance)
}
}

- name: query similar_to with both ef and distance_threshold parameters
gqlquery: |
query {
querySimilarProductByEmbedding(by: productVector, topK: 8, vector: [0.1, 0.2, 0.3, 0.4, 0.5], ef: 128, distance_threshold: 0.75) {
id
title
productVector
}
}

dgquery: |-
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(Product.productVector, 8, $search_vector, ef: 128, distance_threshold: 0.75)) @filter(type(Product)) {
v2 as Product.productVector
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
}
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Product.title : Product.title
Product.productVector : Product.productVector
dgraph.uid : uid
Product.vector_distance : val(distance)
}
}

- name: query vector by id with ef parameter
gqlquery: |
query {
querySimilarProductById(by: productVector, topK: 5, id: "0x1", ef: 64) {
id
title
productVector
}
}

dgquery: |-
query {
var(func: eq(Product.id, "0x1")) @filter(type(Product)) {
vec as Product.productVector
}
var() {
v1 as max(val(vec))
}
var(func: similar_to(Product.productVector, 5, val(v1), ef: 64)) {
v2 as Product.productVector
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
}
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Product.title : Product.title
Product.productVector : Product.productVector
dgraph.uid : uid
Product.vector_distance : val(distance)
}
}

- name: query vector by id with distance_threshold parameter
gqlquery: |
query {
querySimilarProductById(by: productVector, topK: 10, id: "0x1", distance_threshold: 0.5) {
id
title
productVector
}
}

dgquery: |-
query {
var(func: eq(Product.id, "0x1")) @filter(type(Product)) {
vec as Product.productVector
}
var() {
v1 as max(val(vec))
}
var(func: similar_to(Product.productVector, 10, val(v1), distance_threshold: 0.5)) {
v2 as Product.productVector
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
}
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Product.title : Product.title
Product.productVector : Product.productVector
dgraph.uid : uid
Product.vector_distance : val(distance)
}
}

- name: query vector by id with both ef and distance_threshold parameters
gqlquery: |
query {
querySimilarProductById(by: productVector, topK: 8, id: "0x1", ef: 128, distance_threshold: 0.75) {
id
title
productVector
}
}

dgquery: |-
query {
var(func: eq(Product.id, "0x1")) @filter(type(Product)) {
vec as Product.productVector
}
var() {
v1 as max(val(vec))
}
var(func: similar_to(Product.productVector, 8, val(v1), ef: 128, distance_threshold: 0.75)) {
v2 as Product.productVector
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
}
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Product.title : Product.title
Product.productVector : Product.productVector
dgraph.uid : uid
Product.vector_distance : val(distance)
}
}
37 changes: 37 additions & 0 deletions graphql/schema/gqlschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -2088,6 +2088,25 @@ func addSimilarByEmbeddingQuery(schema *ast.Schema, defn *ast.Definition) {
NonNull: true,
},
})

// Accept optional ef parameter for HNSW search
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
Name: SimilarEfArgName,
Type: &ast.Type{
NamedType: "Int",
NonNull: false,
},
})

// Accept optional distance_threshold parameter for filtering results
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
Name: SimilarDistanceThresholdArgName,
Type: &ast.Type{
NamedType: "Float",
NonNull: false,
},
})

addFilterArgument(schema, qry)

schema.Query.Fields = append(schema.Query.Fields, qry)
Expand Down Expand Up @@ -2197,6 +2216,24 @@ func addSimilarByIdQuery(schema *ast.Schema, defn *ast.Definition,
},
})

// Accept optional ef parameter for HNSW search
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
Name: SimilarEfArgName,
Type: &ast.Type{
NamedType: "Int",
NonNull: false,
},
})

// Accept optional distance_threshold parameter for filtering results
qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{
Name: SimilarDistanceThresholdArgName,
Type: &ast.Type{
NamedType: "Float",
NonNull: false,
},
})

addFilterArgument(schema, qry)
schema.Query.Fields = append(schema.Query.Fields, qry)
}
Expand Down
Loading
Loading