Skip to content
This repository was archived by the owner on Sep 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions api_dql.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package modusdb

import "fmt"
import (
"fmt"
)

type QueryFunc func() string

Expand Down Expand Up @@ -31,6 +33,7 @@ const (

funcUid = `func: uid(%d)`
funcEq = `func: eq(%s, %s)`
// funcSimilarTo = `func: similar_to(%s, %d, "%s")`
)

func buildUidQuery(gid uint64) QueryFunc {
Expand All @@ -39,12 +42,24 @@ func buildUidQuery(gid uint64) QueryFunc {
}
}

func buildEqQuery(key, value any) QueryFunc {
func buildEqQuery(key string, value any) QueryFunc {
return func() string {
return fmt.Sprintf(funcEq, key, value)
}
}

// func buildVecSearchQuery(indexAttr string, topK int64, vec []float32) QueryFunc {
// vecStrArr := make([]string, len(vec))
// for i := range vec {
// vecStrArr[i] = strconv.FormatFloat(float64(vec[i]), 'f', -1, 32)
// }
// vecStr := strings.Join(vecStrArr, ",")
// vecStr = "[" + vecStr + "]"
// return func() string {
// return fmt.Sprintf(funcSimilarTo, indexAttr, topK, vecStr)
// }
// }

func formatObjQuery(qf QueryFunc, extraFields string) string {
return fmt.Sprintf(objQuery, qf(), extraFields)
}
Expand Down
50 changes: 36 additions & 14 deletions api_mutate_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,13 @@
}
if jsonToDbTags[jsonName] != nil {
constraint := jsonToDbTags[jsonName].constraint
if constraint == "unique" || constraint == "term" {
uniqueConstraintFound = true
u.Directive = pb.SchemaUpdate_INDEX
if constraint == "unique" {
u.Tokenizer = []string{"exact"}
} else {
u.Tokenizer = []string{"term"}
}
if constraint == "vector" && valType != pb.Posting_VFLOAT {
return fmt.Errorf("vector index can only be applied to []float values")
}
if uniqueConstraintFound {

Check failure on line 96 in api_mutate_helper.go

View workflow job for this annotation

GitHub Actions / ci-go-lint

SA9003: empty branch (staticcheck)

Check failure on line 96 in api_mutate_helper.go

View workflow job for this annotation

GitHub Actions / ci-go-lint

SA9003: empty branch (staticcheck)

}
uniqueConstraintFound = addIndex(u, constraint, uniqueConstraintFound)
}

sch.Preds = append(sch.Preds, u)
Expand Down Expand Up @@ -162,13 +160,10 @@
ValueType: valType,
}
if indexes[pred] != "" {
u.Directive = pb.SchemaUpdate_INDEX
index := indexes[pred]
if index == "unique" {
u.Tokenizer = []string{"exact"}
} else if index == "term" {
u.Tokenizer = []string{"term"}
if indexes[pred] == "vector" && valType != pb.Posting_VFLOAT {
return fmt.Errorf("vector index can only be applied to []float values")
}
addIndex(u, indexes[pred], false)
}

sch.Preds = append(sch.Preds, u)
Expand Down Expand Up @@ -289,3 +284,30 @@

return gid, nil
}

func addIndex(u *pb.SchemaUpdate, index string, uniqueConstraintExists bool) bool {
u.Directive = pb.SchemaUpdate_INDEX
switch index {
case "unique":
u.Tokenizer = []string{"exact"}
uniqueConstraintExists = true
case "term":
u.Tokenizer = []string{"term"}
uniqueConstraintExists = true
case "vector":
u.IndexSpecs = []*pb.VectorIndexSpec{
{
Name: "hnsw",
Options: []*pb.OptionPair{
{
Key: "metric",
Value: "cosine",
},
},
},
}
default:
return uniqueConstraintExists
}
return uniqueConstraintExists
}
2 changes: 1 addition & 1 deletion api_query_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac
return 0, nil, fmt.Errorf("invalid unique field type")
}

if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" {
if jsonToDbTag[cf.Key] == nil || jsonToDbTag[cf.Key].constraint == "" {
return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key)
}

Expand Down
146 changes: 145 additions & 1 deletion api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ func TestRawAPIs(t *testing.T) {
require.Equal(t, gid, getGid)
require.Equal(t, "B", maps["name"])

//TODO figure out why it comes back as a flaot64
//TODO figure out why it comes back as a float64
// schema and value are correctly set to int, so its a query side issue
require.Equal(t, float64(20), maps["age"])
require.Equal(t, "123", maps["clerk_id"])
Expand All @@ -565,3 +565,147 @@ func TestRawAPIs(t *testing.T) {
require.NoError(t, err)
require.Equal(t, gid, deleteGid)
}

func TestVectorIndexInsert(t *testing.T) {
ctx := context.Background()
db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir()))
require.NoError(t, err)
defer db.Close()

db1, err := db.CreateNamespace()
require.NoError(t, err)

require.NoError(t, db1.DropData(ctx))

_, err = modusdb.RawCreate(db, map[string]any{
"name": "B",
"age": 20,
"clerk_id": "123",
"vec": []float64{1.0, 2.0, 3.0},
}, map[string]string{
"clerk_id": "unique",
"vec": "vector",
}, db1.ID())

require.NoError(t, err)

query := `{
me(func: has(name)) {
uid
name
age
clerk_id
vec
}
}`
resp, err := db1.Query(ctx, query)
require.NoError(t, err)
require.JSONEq(t, `{"me":[{"uid":"0x2","name":"B","age":20,"clerk_id":"123","vec":[1,2,3]}]}`,
string(resp.GetJson()))

}

func TestVectorIndexSearchUntyped(t *testing.T) {
ctx := context.Background()
db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir()))
require.NoError(t, err)
defer db.Close()

db1, err := db.CreateNamespace()
require.NoError(t, err)

require.NoError(t, db1.DropData(ctx))

vectors := [][]float64{
{1.0, 2.0, 3.0}, // Sequential
{4.0, 5.0, 6.0}, // Sequential continued
{7.0, 8.0, 9.0}, // Sequential continued
{0.1, 0.2, 0.3}, // Small decimals
{1.5, 2.5, 3.5}, // Half steps
{1.0, 2.0, 3.0}, // Duplicate
{10.0, 20.0, 30.0}, // Tens
{0.5, 1.0, 1.5}, // Half increments
{2.2, 4.4, 6.6}, // Multiples of 2.2
{3.3, 6.6, 9.9}, // Multiples of 3.3
}

for _, vec := range vectors {
_, err = modusdb.RawCreate(db, map[string]any{
"vec": vec,
}, map[string]string{
"vec": "vector",
}, db1.ID())
require.NoError(t, err)
}

const query = `
{
vector(func: similar_to(vec, 5, "[4.1,5.1,6.1]")) {
vec
}
}`

resp, err := db1.Query(ctx, query)
require.NoError(t, err)
require.JSONEq(t, `{
"vector":[
{"vec":[4,5,6]},
{"vec":[7,8,9]},
{"vec":[1.5,2.5,3.5]},
{"vec":[10,20,30]},
{"vec":[2.2,4.4,6.6]}
]
}`, string(resp.GetJson()))
}

type Document struct {
Gid uint64 `json:"gid,omitempty"`
Text string `json:"text,omitempty"`
TextVec []float32 `json:"textVec,omitempty" db:"constraint=vector"`
}

func TestVectorIndexSearchTyped(t *testing.T) {
ctx := context.Background()
db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir()))
require.NoError(t, err)
defer db.Close()

db1, err := db.CreateNamespace()
require.NoError(t, err)

require.NoError(t, db1.DropData(ctx))

documents := []Document{
{Text: "apple", TextVec: []float32{1.0, 0.0, 0.0}},
{Text: "banana", TextVec: []float32{0.0, 1.0, 0.0}},
{Text: "carrot", TextVec: []float32{0.0, 0.0, 1.0}},
{Text: "dog", TextVec: []float32{1.0, 1.0, 0.0}},
{Text: "elephant", TextVec: []float32{0.0, 1.0, 1.0}},
{Text: "fox", TextVec: []float32{1.0, 0.0, 1.0}},
{Text: "gorilla", TextVec: []float32{1.0, 1.0, 1.0}},
}

for _, doc := range documents {
_, _, err = modusdb.Create(db, &doc, db1.ID())
require.NoError(t, err)
}

const query = `
{
documents(func: similar_to(Document.textVec, 5, "[0.1,0.1,0.1]")) {
Document.text
}
}`

resp, err := db1.Query(ctx, query)
require.NoError(t, err)
require.JSONEq(t, `{
"documents":[
{"Document.text":"apple"},
{"Document.text":"dog"},
{"Document.text":"elephant"},
{"Document.text":"fox"},
{"Document.text":"gorilla"}
]
}`, string(resp.GetJson()))
}
11 changes: 11 additions & 0 deletions api_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/dgraph-io/dgo/v240/protos/api"
"github.com/dgraph-io/dgraph/v24/protos/pb"
"github.com/dgraph-io/dgraph/v24/types"
"github.com/dgraph-io/dgraph/v24/x"
"github.com/twpayne/go-geom"
"github.com/twpayne/go-geom/encoding/wkb"
Expand Down Expand Up @@ -110,6 +111,16 @@ func valueToApiVal(v any) (*api.Value, error) {
return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: float64(val)}}, nil
case float64:
return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}}, nil
case []float32:
return &api.Value{Val: &api.Value_Vfloat32Val{
Vfloat32Val: types.FloatArrayAsBytes(val)}}, nil
case []float64:
float32Slice := make([]float32, len(val))
for i, v := range val {
float32Slice[i] = float32(v)
}
return &api.Value{Val: &api.Value_Vfloat32Val{
Vfloat32Val: types.FloatArrayAsBytes(float32Slice)}}, nil
case []byte:
return &api.Value{Val: &api.Value_BytesVal{BytesVal: val}}, nil
case time.Time:
Expand Down
Loading