Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
b120554
add code for bm25 with rrf. slight restructure.
Oct 21, 2023
a8679a8
add simple sentence query decomposition
Oct 21, 2023
c0e6ffa
add runs with decomposition + rrf and rm3
Oct 21, 2023
81da413
add runs with decomposition + rrf and rm3
JMVCoelho Oct 21, 2023
81fb401
Merge branch 'main' of https://github.com/JMVCoelho/llms-project
JMVCoelho Oct 21, 2023
f4a25f6
Update README.md
JMVCoelho Oct 21, 2023
144de78
zero-shot-gpt-baseline
aprameyabharadwaj Oct 23, 2023
6d942e3
fix
aprameyabharadwaj Oct 23, 2023
b067b8d
add running scripts
JMVCoelho Oct 28, 2023
ea9062b
add precision metric. update dense retrieval run script to support se…
JMVCoelho Oct 29, 2023
4f5de72
Added code for GPT-4 for query decomposition
Oct 29, 2023
a27631a
Merge branch 'main' into cdgamaro
Oct 29, 2023
77d46b0
Merge pull request #1 from JMVCoelho/cdgamaro-query-decomposition
chantal-rose Oct 29, 2023
6690fc3
Added basline metrics
rosaboyle Oct 30, 2023
ea1c029
Decomposed Queries
Oct 31, 2023
5427221
Runs
Oct 31, 2023
e5ad958
Update DENSE.md
rosaboyle Oct 31, 2023
1f173b9
Saving indices
Oct 31, 2023
cb99701
Undoing edits in bm25
Oct 31, 2023
fb751ce
Dense query decomposition
Oct 31, 2023
85adf04
Dense RRF
Oct 31, 2023
126efa7
Merge branch 'main' of https://github.com/JMVCoelho/llms-project
Oct 31, 2023
d48a112
add gpt4 reranker
JMVCoelho Nov 7, 2023
b0b9483
runs
JMVCoelho Nov 7, 2023
a02e8d7
Merge branch 'main' of https://github.com/JMVCoelho/llms-project
JMVCoelho Nov 7, 2023
79dd984
pointwise_code
aprameyabharadwaj Nov 22, 2023
b4f833e
Dense RRF Inference Ablations
Nov 28, 2023
e05b163
LLM Decomposition results
Nov 28, 2023
8e7c3ee
update gpt4 rerank with retries and evaluation
JMVCoelho Dec 2, 2023
7d754bc
Merge branch 'main' of https://github.com/JMVCoelho/llms-project
JMVCoelho Dec 2, 2023
aea1e99
Merge pull request #2 from JMVCoelho/ambharad
aprameyabharadwaj Dec 2, 2023
f4635b7
Updated prompt for query decomposition
Dec 3, 2023
4dd02cf
Merge remote-tracking branch 'origin/main'
Dec 3, 2023
3876455
Removed openai key
Dec 3, 2023
90e3d64
Decomposed DistilBert
Dec 4, 2023
afeecb3
Merge branch 'main' of https://github.com/JMVCoelho/llms-project
Dec 4, 2023
a883e11
Update README.md
JMVCoelho Dec 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
dense_index_*
output_dense_*
runs_dense/*
scripts/runs_dense/*
runs
.idea
.python-version
anserini_docs
anserini_indicies
data
.DS_Store
dense_hs.*
best_models.py
Expand Down Expand Up @@ -171,3 +177,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

dense_models_backup/
*.index
29 changes: 29 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false,
// "args": [ "input","--corpus","/home/ddo/CMU/PLLM/TREC-TOT/corpus.jsonl","--fields","text" ,
// // "--delimiter","\\n",
// "--shard-id","0" ,
// "--shard-num","1" ,
// "output","--embeddings","../output_dense_enc_1" ,
// "--to-faiss" ,
// "encoder","--encoder","/home/ddo/CMU/PLLM/llms-project/dense_models/baseline_distilbert/model" ,
// "--fields","text",
// "--batch", "32" ]
// "args": []
"args": ["--epochs","20","--lr","6e-05","--weight_decay","0.01","--model_dir","dense_models/baseline_distilbert_0","--run_id","baseline_distilbert_0","--model_or_checkpoint","distilbert-base-uncased","--embed_size","768","--batch_size","10","--encode_batch_size","128","--data_path","/home/ddo/CMU/PLLM/TREC-TOT","--negatives_path","/home/ddo/CMU/PLLM/TREC-TOT/negatives/bm25_negatives","--negatives_out","/home/ddo/CMU/PLLM/TREC-TOT/negatives/baseline_distilbert_0_negatives","--query","title_text","--device","cuda"]
// "args": ["--epochs","20","--lr","6e-05","--weight_decay","0.01","--model_dir","dense_models/baseline_distilbert_0","--run_id","baseline_distilbert_0","--model_or_checkpoint","distilbert-base-uncased","--embed_size","768","--batch_size","10","--encode_batch_size","128","--data_path","/home/ddo/CMU/PLLM/TREC-TOT","--negatives_path","/home/ddo/CMU/PLLM/TREC-TOT/negatives/bm25_negatives","--negatives_out","/home/ddo/CMU/PLLM/TREC-TOT/negatives/baseline_distilbert_0_negatives","--query","title_text","--device","cuda"]
// "args": ["--epochs","20","--lr","6e-05","--weight_decay","0.01","--model_dir","dense_models/baseline_distilbert_0","--run_id","baseline_distilbert_0","--model_or_checkpoint","distilbert-base-uncased","--embed_size","768","--batch_size","10","--encode_batch_size","128","--data_path","/home/ddo/CMU/PLLM/TREC-TOT","--negatives_path","/home/ddo/CMU/PLLM/TREC-TOT/negatives/bm25_negatives","--negatives_out","/home/ddo/CMU/PLLM/TREC-TOT/negatives/baseline_distilbert_0_negatives","--query","title_text","--device","cuda"]
}
]
}
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Benchmarks for TREC-ToT (2023)

NOTE: code is actively being re-organized!

The following benchmarks (& runs) are available:


| Benchmark | Runfiles | Dev-DCG | Dev-Success@1000 | Dev-MRR |
|----------------------|----------|----------|-----------------|-------|
| [BM25](BM25.md) (k1=0.8, b=1.0) | [train](runs/bm25/train.run), [dev](runs/bm25/dev.run) | 0.1314 | 0.4067 | 0.0881 |
| [Dense Retrieval (SBERT)](DENSE.md) (Distilbert) | [train](runs/distilbert/train.run), [dev](runs/distilbert/dev.run) | 0.1627 | 0.6600 | 0.0743 |
| [GPT-4](GPT4.md)* | [train](runs/gpt4/train.run), [dev](runs/gpt4/dev.run) | 0.2407 | 0.3200 | 0.2180 |
| [BM25](docs/BM25.md) (k1=0.8, b=1.0) | [train](runs/bm25/train.run), [dev](runs/bm25/dev.run) | 0.1314 | 0.4067 | 0.0881 |
| [Dense Retrieval (SBERT)](docs/DENSE.md) (Distilbert) | [train](runs/distilbert/train.run), [dev](runs/distilbert/dev.run) | 0.1627 | 0.6600 | 0.0743 |
| [GPT-4](docs/GPT4.md)* | [train](runs/gpt4/train.run), [dev](runs/gpt4/dev.run) | 0.2407 | 0.3200 | 0.2180 |

*: GPT-4 generates 20 candidates at most. See [GPT4](docs/GPT4.md) for more details.

Expand All @@ -18,7 +20,7 @@ The following benchmarks (& runs) are available:
## optional: create new environment using py-env virtual-env
## pyenv virtualenv 3.8.11 trec-tot-benchmarks
# install requirements
pip install ir_datasets sentence-transformers==2.2.2 pyserini==0.20.0 pytrec_eval faiss-cpu==1.6.5
pip install ir_datasets sentence-transformers==2.2.2 pyserini==0.20.0 pytrec_eval faiss-cpu==1.6.5 ranx==0.3.7
```

After downloading the files (see guidelines), set DATA_PATH to the folder which
Expand All @@ -39,4 +41,4 @@ Quick test to see if data is setup properly:
python tot.py
```
The command above should print the correct number of train/dev queries and the number of documents
in the corpus, along with example queries and documents.
in the corpus, along with example queries and documents.
2 changes: 1 addition & 1 deletion bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

log = logging.getLogger(__name__)

METRICS = "recall_10,recall_100,recall_1000,ndcg_cut_10,ndcg_cut_100,ndcg_cut_1000,recip_rank"
METRICS = "P_1,recall_10,recall_100,recall_1000,ndcg_cut_10,ndcg_cut_100,ndcg_cut_1000,recip_rank"


def create_index(dataset, field_to_index, dest_folder, index):
Expand Down
121 changes: 121 additions & 0 deletions bm25_with_rrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import json
import argparse
import os
import pandas as pd
import logging


from tqdm import tqdm
from typing import Dict
from pyserini.search.lucene import LuceneSearcher
from pyserini.trectools import TrecRun
from pyserini.fusion import reciprocal_rank_fusion
from ranx import Qrels, Run, evaluate

from modules import llm_based_decomposition, sentence_decomposition
import tot
import ir_datasets
from src import utils
import pytrec_eval

METRICS = "P_1,recall_10,recall_100,recall_1000,ndcg_cut_10,ndcg_cut_100,ndcg_cut_1000,recip_rank"

log = logging.getLogger(__name__)


def main():
    """BM25 retrieval over decomposed queries, fused with reciprocal rank fusion.

    Each original query is decomposed (LLM- or sentence-based) into sub-queries;
    BM25 (optionally with RM3) retrieves top-K per sub-query, the per-sub-query
    runs are fused with RRF, the fused run is saved, and — when the split has
    qrels — evaluated with pytrec_eval.
    """
    parser = argparse.ArgumentParser()
    # Path to indexes directory
    parser.add_argument("--index_name", default="bm25_0.8_1.0", help="name of index")

    parser.add_argument("--decomposition_method", default="llm", help="how to decompose")

    parser.add_argument("--data_path", default="./data", help="location to dataset")

    parser.add_argument("--split", choices={"train", "dev", "test"}, default="dev", help="split to run")

    parser.add_argument("--index_path", default="./anserini_indicies", help="path to store (all) indices")

    parser.add_argument("--metrics", default=METRICS, help="csv - metrics to evaluate")

    parser.add_argument("--param_k1", default=0.8, type=float, help="param: k1 for BM25")

    parser.add_argument("--param_b", default=1.0, type=float, help="param: b for BM25")

    # BM25 parameters
    parser.add_argument('--K', type=int, help='retrieve top K documents', default=1000)

    # Binary flags to enable or disable ranking methods
    parser.add_argument('--rm3', type=str, help='enable or disable rm3', choices=['y', 'n'], default='n')

    # Run number
    parser.add_argument('--run_number', type=int, help='run number', default=1)

    # Output options and directory
    parser.add_argument('--output_dir', type=str, help='path to output_dir', default="runs/")
    args = parser.parse_args()

    tot.register(args.data_path)

    irds_name = "trec-tot:" + args.split
    dataset = ir_datasets.load(irds_name)

    # Decompose queries; both helpers return the path to a JSON file mapping
    # query_id -> {sub_query_id -> sub_query_text}.
    if args.decomposition_method == "llm":
        queries_expanded = llm_based_decomposition(dataset, f"{args.data_path}/decomposed_queries")
    else:
        queries_expanded = sentence_decomposition(dataset, f"{args.data_path}/decomposed_queries")

    # Context manager fixes the leaked file handle of json.load(open(...)).
    with open(queries_expanded) as fh:
        queries = json.load(fh)

    # Run folder name encodes decomposition method, RM3 usage and run number.
    run_save_folder = f"{args.output_dir}BM25-RRF"
    if args.decomposition_method == "llm":
        run_save_folder += "-llm"

    run_save_folder += f"-RM3-{args.run_number}" if args.rm3 == 'y' else f"-{args.run_number}"

    run_save_full = f"{run_save_folder}/{args.split}.run"

    searcher = LuceneSearcher(os.path.join(args.index_path, args.index_name))
    searcher.set_bm25(k1=args.param_k1, b=args.param_b)

    if args.rm3 == 'y':
        searcher.set_rm3()

    # Retrieve: one TrecRun per sub-query; all rows keep the PARENT query_id so
    # RRF fuses sub-query rankings into a single ranking per original query.
    run_result = []

    for query_id in tqdm(queries):
        for synthetic_query_id in queries[query_id]:
            hits = searcher.search(str(queries[query_id][synthetic_query_id]), k=args.K)
            synthetic_query_results = [
                (query_id, 'Q0', hit.docid, rank, hit.score, f'{query_id}_{synthetic_query_id}')
                for rank, hit in enumerate(hits, start=1)
            ]

            if synthetic_query_results:
                run_result.append(TrecRun.from_list(synthetic_query_results))

    results = reciprocal_rank_fusion(run_result, depth=args.K, k=args.K)

    print(f"saving run to: {run_save_full}")
    os.makedirs(os.path.dirname(run_save_full), exist_ok=True)
    results.save_to_txt(run_save_full)

    if dataset.has_qrels():
        with open(run_save_full, 'r') as h:
            run_to_eval = pytrec_eval.parse_run(h)

        qrel, n_missing = utils.get_qrel(dataset, run_to_eval)
        if n_missing > 0:
            raise ValueError(f"Number of missing qids in run: {n_missing}")

        evaluator = pytrec_eval.RelevanceEvaluator(
            qrel, args.metrics.split(","))

        eval_res = evaluator.evaluate(run_to_eval)

        eval_res_agg = utils.aggregate_pytrec(eval_res, "mean")

        for metric, (mean, std) in eval_res_agg.items():
            print(f"{metric:<12}: {mean:.4f} ({std:0.4f})")


if __name__ == '__main__':
    main()
144 changes: 144 additions & 0 deletions dense_inference_decomp_with_rrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import json
import argparse
import os
import pandas as pd
import logging


from tqdm import tqdm
from typing import Dict
from pyserini.search.lucene import LuceneSearcher
from pyserini.trectools import TrecRun
from pyserini.fusion import reciprocal_rank_fusion
from ranx import Qrels, Run, evaluate

from modules import llm_based_decomposition, sentence_decomposition
import tot
import ir_datasets
from src import utils
import pytrec_eval

METRICS = "P_1,recall_10,recall_100,recall_1000,ndcg_cut_10,ndcg_cut_100,ndcg_cut_1000,recip_rank"

log = logging.getLogger(__name__)


def main():
    """Dense (FAISS) retrieval over decomposed queries, fused with RRF.

    Mirrors bm25_with_rrf.py but searches a FAISS index with a dense query
    encoder. Sub-query runs are fused with reciprocal rank fusion, saved, and —
    when the split has qrels — evaluated with pytrec_eval.
    """
    parser = argparse.ArgumentParser()
    # Path to indexes directory
    parser.add_argument("--index_name", default="bm25_0.8_1.0", help="name of index")

    parser.add_argument("--decomposition_method", default="llm", help="how to decompose")

    parser.add_argument("--data_path", default="/home/ddo/CMU/PLLM/TREC-TOT", help="location to dataset")

    parser.add_argument("--split", choices={"train", "dev", "test"}, default="dev", help="split to run")

    parser.add_argument("--index_path", default="./dense_index_hnsw/", help="path to store (all) indices")

    parser.add_argument("--metrics", default=METRICS, help="csv - metrics to evaluate")

    parser.add_argument("--param_k1", default=0.8, type=float, help="param: k1 for BM25")

    parser.add_argument("--param_b", default=1.0, type=float, help="param: b for BM25")

    # BM25 parameters
    parser.add_argument('--K', type=int, help='retrieve top K documents', default=1000)

    # Binary flags to enable or disable ranking methods
    parser.add_argument('--rm3', type=str, help='enable or disable rm3', choices=['y', 'n'], default='n')

    # Run number
    parser.add_argument('--run_number', type=int, help='run number', default=1)

    # Output options and directory
    parser.add_argument('--output_dir', type=str, help='path to output_dir', default="runs_dense/")
    parser.add_argument('--load_queries', type=str, help='used previously decomposed queries', choices=['y', 'n'], default='y')
    # Was a hard-coded absolute path; now configurable with the same default.
    parser.add_argument('--encoder_path', type=str,
                        help='path to the dense query encoder model',
                        default='/home/ddo/CMU/PLLM/llms-project/dense_models/baseline_distilbert/model')
    args = parser.parse_args()

    tot.register(args.data_path)
    # BUG FIX: args.split and args.load_queries were previously overwritten here
    # ('train' / 'y'), silently ignoring the values passed on the command line.

    irds_name = "trec-tot:" + args.split
    dataset = ir_datasets.load(irds_name)

    # Either decompose now or reuse a previously saved decomposition file.
    decomp_dir = f"{args.data_path}/decomposed_queries/{args.split}"
    if args.decomposition_method == "llm":
        if args.load_queries == 'n':
            queries_expanded = llm_based_decomposition(dataset, decomp_dir)
        else:
            queries_expanded = f"{decomp_dir}/llm_decomposed_queries.json"
    else:
        if args.load_queries == 'n':
            queries_expanded = sentence_decomposition(dataset, decomp_dir)
        else:
            queries_expanded = f"{decomp_dir}/sentence_decomposed_queries.json"

    # Context manager fixes the leaked file handle of json.load(open(...)).
    with open(queries_expanded) as fh:
        queries = json.load(fh)

    # Run folder name encodes split, decomposition method, run number and index.
    run_save_folder = f"{args.output_dir}DENSE-RRF-{args.split}"
    if args.decomposition_method == "llm":
        run_save_folder += "-llm"

    run_save_folder += f"-RM3-{args.run_number}" if args.rm3 == 'y' else f"-{args.run_number}"
    run_save_folder += args.index_path.split('/')[-1]
    run_save_full = f"{run_save_folder}/{args.split}.run"

    # Local import keeps heavy FAISS dependencies out of module import time.
    from pyserini.search.faiss import FaissSearcher, AutoQueryEncoder

    print(args.index_path)
    encoder = AutoQueryEncoder(args.encoder_path)
    searcher = FaissSearcher(args.index_path, encoder)

    if args.rm3 == 'y':
        # BUG FIX: FaissSearcher has no set_rm3(); calling it raised
        # AttributeError. RM3 is a sparse-retrieval technique, so warn and skip.
        log.warning("RM3 is not supported for dense (FAISS) retrieval; ignoring --rm3.")

    # Retrieve: one TrecRun per sub-query; all rows keep the PARENT query_id so
    # RRF fuses sub-query rankings into a single ranking per original query.
    run_result = []

    for query_id in tqdm(queries):
        for synthetic_query_id in queries[query_id]:
            hits = searcher.search(str(queries[query_id][synthetic_query_id]), k=args.K)
            synthetic_query_results = [
                (query_id, 'Q0', hit.docid, rank, hit.score, f'{query_id}_{synthetic_query_id}')
                for rank, hit in enumerate(hits, start=1)
            ]

            if synthetic_query_results:
                run_result.append(TrecRun.from_list(synthetic_query_results))

    results = reciprocal_rank_fusion(run_result, depth=args.K, k=args.K)

    print(f"saving run to: {run_save_full}")
    os.makedirs(os.path.dirname(run_save_full), exist_ok=True)
    results.save_to_txt(run_save_full)

    if dataset.has_qrels():
        with open(run_save_full, 'r') as h:
            run_to_eval = pytrec_eval.parse_run(h)

        qrel, n_missing = utils.get_qrel(dataset, run_to_eval)
        if n_missing > 0:
            raise ValueError(f"Number of missing qids in run: {n_missing}")

        evaluator = pytrec_eval.RelevanceEvaluator(
            qrel, args.metrics.split(","))

        eval_res = evaluator.evaluate(run_to_eval)

        eval_res_agg = utils.aggregate_pytrec(eval_res, "mean")

        for metric, (mean, std) in eval_res_agg.items():
            print(f"{metric:<12}: {mean:.4f} ({std:0.4f})")


if __name__ == '__main__':
    main()
Loading