Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions index_alias_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,15 @@ func MultiSearch(ctx context.Context, req *SearchRequest, params *multiSearchPar
searchStart := time.Now()
asyncResults := make(chan *asyncSearchResult, len(indexes))

var preSearchData map[string]map[string]interface{}
var rescorer *rescorer
var fusionKnnHits search.DocumentMatchCollection
if params != nil {
preSearchData = params.preSearchData
rescorer = params.rescorer
fusionKnnHits = params.fusionKnnHits
}

var reverseQueryExecution bool
if req.SearchBefore != nil {
reverseQueryExecution = true
Expand All @@ -1006,8 +1015,8 @@ func MultiSearch(ctx context.Context, req *SearchRequest, params *multiSearchPar
waitGroup.Add(len(indexes))
for _, in := range indexes {
var payload map[string]interface{}
if params.preSearchData != nil {
payload = params.preSearchData[in.Name()]
if preSearchData != nil {
payload = preSearchData[in.Name()]
}
go searchChildIndex(in, createChildSearchRequest(req, payload))
}
Expand Down Expand Up @@ -1047,9 +1056,9 @@ func MultiSearch(ctx context.Context, req *SearchRequest, params *multiSearchPar
}
}

if params.rescorer != nil {
sr.Hits, sr.Total, sr.MaxScore = params.rescorer.rescore(sr.Hits, params.fusionKnnHits)
params.rescorer.restoreSearchRequest()
if rescorer != nil {
sr.Hits, sr.Total, sr.MaxScore = rescorer.rescore(sr.Hits, fusionKnnHits)
rescorer.restoreSearchRequest()
}

sr.Hits = hitsInCurrentPage(req, sr.Hits)
Expand Down
25 changes: 9 additions & 16 deletions index_alias_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,7 @@ func TestMultiSearchNoError(t *testing.T) {
MaxScore: 2.0,
}

multiSearchParams := &multiSearchParams{nil, nil, nil}
results, err := MultiSearch(context.Background(), sr, multiSearchParams, ei1, ei2)
results, err := MultiSearch(context.Background(), sr, nil, ei1, ei2)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -626,8 +625,7 @@ func TestMultiSearchSomeError(t *testing.T) {
}}
ei2 := &stubIndex{name: "ei2", err: fmt.Errorf("deliberate error")}
sr := NewSearchRequest(NewTermQuery("test"))
multiSearchParams := &multiSearchParams{nil, nil, nil}
res, err := MultiSearch(context.Background(), sr, multiSearchParams, ei1, ei2)
res, err := MultiSearch(context.Background(), sr, nil, ei1, ei2)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
Expand All @@ -654,8 +652,7 @@ func TestMultiSearchAllError(t *testing.T) {
ei1 := &stubIndex{name: "ei1", err: fmt.Errorf("deliberate error")}
ei2 := &stubIndex{name: "ei2", err: fmt.Errorf("deliberate error")}
sr := NewSearchRequest(NewTermQuery("test"))
multiSearchParams := &multiSearchParams{nil, nil, nil}
res, err := MultiSearch(context.Background(), sr, multiSearchParams, ei1, ei2)
res, err := MultiSearch(context.Background(), sr, nil, ei1, ei2)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
Expand Down Expand Up @@ -711,8 +708,7 @@ func TestMultiSearchSecondPage(t *testing.T) {
checkRequest: checkRequest,
}
sr := NewSearchRequestOptions(NewTermQuery("test"), 10, 10, false)
multiSearchParams := &multiSearchParams{nil, nil, nil}
_, err := MultiSearch(context.Background(), sr, multiSearchParams, ei1, ei2)
_, err := MultiSearch(context.Background(), sr, nil, ei1, ei2)
if err != nil {
t.Errorf("unexpected error %v", err)
}
Expand Down Expand Up @@ -791,8 +787,7 @@ func TestMultiSearchTimeout(t *testing.T) {
defer cancel()
query := NewTermQuery("test")
sr := NewSearchRequest(query)
multiSearchParams := &multiSearchParams{nil, nil, nil}
res, err := MultiSearch(ctx, sr, multiSearchParams, ei1, ei2)
res, err := MultiSearch(ctx, sr, nil, ei1, ei2)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
Expand All @@ -812,7 +807,7 @@ func TestMultiSearchTimeout(t *testing.T) {
// now run a search again with an absurdly low timeout (should timeout)
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
res, err = MultiSearch(ctx, sr, multiSearchParams, ei1, ei2)
res, err = MultiSearch(ctx, sr, nil, ei1, ei2)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
Expand All @@ -839,7 +834,7 @@ func TestMultiSearchTimeout(t *testing.T) {
// now run a search again with a normal timeout, but cancel it first
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
cancel()
res, err = MultiSearch(ctx, sr, multiSearchParams, ei1, ei2)
res, err = MultiSearch(ctx, sr, nil, ei1, ei2)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
Expand Down Expand Up @@ -974,8 +969,7 @@ func TestMultiSearchTimeoutPartial(t *testing.T) {
MaxScore: 2.0,
}

multiSearchParams := &multiSearchParams{nil, nil, nil}
res, err := MultiSearch(ctx, sr, multiSearchParams, ei1, ei2, ei3)
res, err := MultiSearch(ctx, sr, nil, ei1, ei2, ei3)
if err != nil {
t.Fatalf("expected no err, got %v", err)
}
Expand Down Expand Up @@ -1233,8 +1227,7 @@ func TestMultiSearchCustomSort(t *testing.T) {
MaxScore: 3.0,
}

multiSearchParams := &multiSearchParams{nil, nil, nil}
results, err := MultiSearch(context.Background(), sr, multiSearchParams, ei1, ei2)
results, err := MultiSearch(context.Background(), sr, nil, ei1, ei2)
if err != nil {
t.Error(err)
}
Expand Down
Loading