Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions NLP_and_backend/search_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,28 @@ def semantic_search_sbert(query:str, command_embeddings, model:str):
query_embedding = embedder.encode(query,
convert_to_numpy=True,
normalize_embeddings=True)

embedding_command_ids = command_embeddings[:, 0, :]
embedding_titles = command_embeddings[:, 1, :]

scores_ids = sentence_transformers.util.semantic_search(query_embedding,
embedding_command_ids,
top_k=20000, # this k needs to be very big
score_function=sentence_transformers.util.dot_score
)[0] # index 0 because we have only one query

return sentence_transformers.util.semantic_search(query_embedding,
command_embeddings,
scores_titles = sentence_transformers.util.semantic_search(query_embedding,
embedding_titles,
top_k=20000, # this k needs to be very big
score_function=sentence_transformers.util.dot_score
)[0] # index 0 because we have only one query

# semantic_search returns each hit list sorted by decreasing score, so
# position i in scores_ids and scores_titles generally refers to two
# different commands. Pair the hits by corpus_id rather than by rank,
# keeping the better of the id-based and title-based score per command.
best_score_by_corpus_id = {hit['corpus_id']: hit['score'] for hit in scores_ids}
for hit in scores_titles:
    corpus_id = hit['corpus_id']
    # A corpus_id can be missing from one list when top_k truncates it.
    previous = best_score_by_corpus_id.get(corpus_id, hit['score'])
    best_score_by_corpus_id[corpus_id] = max(previous, hit['score'])

# Re-rank by the combined score so the result is best-first again, the
# same ordering a single semantic_search hit list has.
combined_scores = sorted(
    ({'corpus_id': corpus_id, 'score': score}
     for corpus_id, score in best_score_by_corpus_id.items()),
    key=lambda hit: hit['score'],
    reverse=True,
)

return combined_scores

# cosine_scores = sentence_transformers.util.dot_score(
# query_embedding, command_embeddings)
Expand All @@ -36,8 +52,7 @@ def semantic_search(query:str, command_embeddings, method:str, model:str):

command_dict_list: a list of commands,
each a dict:
{"key":str, "command":str, "when":str, "to-ebd": str}
to-ebd is the string to embed the command itself or its label
{"command_id":str, "command_title":str, "command_id_normalized":str}

"""

Expand Down Expand Up @@ -72,8 +87,8 @@ def combine_results(scores:dict[list], command_dict_list, k, p):
cutoff_index = filter_results([s['score'] for s in scores], k, p)

for score in scores[:cutoff_index]:
command_id = score['corpus_id']
command = command_dict_list[command_id]['command']
command_idx = score['corpus_id']
command = command_dict_list[command_idx]['command_id']
results.append((command, score['score']))
return results

Expand Down