diff --git a/index.go b/index.go index 58543f2..cf7dc32 100644 --- a/index.go +++ b/index.go @@ -331,6 +331,15 @@ func (idx *faissIndex) Reconstruct(key int64) (recons []float32, err error) { func (idx *faissIndex) ReconstructBatch(keys []int64, recons []float32) ([]float32, error) { var err error n := int64(len(keys)) + if recons == nil { + recons = make([]float32, n*int64(idx.D())) + } + + // exit in case of invalid input + if n == 0 || len(recons) != int(n)*idx.D() { + return nil, fmt.Errorf("invalid input") + } + if c := C.faiss_Index_reconstruct_batch( idx.idx, C.idx_t(n),