-
Notifications
You must be signed in to change notification settings - Fork 94
Expand file tree
/
Copy pathbuild_bm25_index.py
More file actions
126 lines (109 loc) · 4.86 KB
/
build_bm25_index.py
File metadata and controls
126 lines (109 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import json
import os
import os.path as osp
import pickle
import subprocess
import time
from pathlib import Path
from queue import Empty

import torch.multiprocessing as mp
from datasets import load_dataset

# from dependency_graph.build_graph import build_graph, VERSION
from util.benchmark.setup_repo import setup_repo
from plugins.location_tools.retriever.bm25_retriever import (
    build_code_retriever_from_repo as build_code_retriever
)
def list_folders(path):
    """Return the names of the immediate subdirectories of *path*."""
    folders = []
    for entry in Path(path).iterdir():
        if entry.is_dir():
            folders.append(entry.name)
    return folders
def run(rank, repo_queue, repo_path, out_path,
        download_repo=False, instance_data=None, similarity_top_k=10):
    """Worker loop: pull repo/instance names off the shared queue and build
    a persisted BM25 index for each until the queue is empty.

    Args:
        rank: Worker index; used for per-worker clone dirs and log prefixes.
        repo_queue: Shared queue of repo/instance names to process.
        repo_path: Base directory holding (or receiving) the codebases.
        out_path: Directory where each repo's persisted index is written.
        download_repo: If True, clone/checkout the repo via ``setup_repo``
            before indexing; otherwise the codebase is expected at
            ``repo_path/<name>``.
        instance_data: Mapping of instance_id -> benchmark instance; must be
            provided when ``download_repo`` is True.
        similarity_top_k: Top-k setting forwarded to the retriever builder.
    """
    while True:
        # Drain the shared queue; queue.Empty is the only "no more work"
        # signal — catching it specifically avoids masking real errors.
        try:
            repo_name = repo_queue.get_nowait()
        except Empty:
            break
        output_file = osp.join(out_path, repo_name)
        if osp.exists(output_file):
            # Index already persisted by a previous run; skip.
            continue
        if download_repo:
            # Each worker clones into its own subdirectory to avoid clashes.
            repo_base_dir = str(osp.join(repo_path, str(rank)))
            os.makedirs(repo_base_dir, exist_ok=True)
            try:
                repo_dir = setup_repo(instance_data=instance_data[repo_name],
                                      repo_base_dir=repo_base_dir,
                                      dataset=None)
            except subprocess.CalledProcessError as e:
                print(f'[{rank}] Error checkout commit {repo_name}: {e}')
                continue
        else:
            repo_dir = osp.join(repo_path, repo_name)
        print(f'[{rank}] Start process {repo_name}')
        try:
            # The builder persists the index to output_file as a side effect;
            # the returned retriever object is not needed here.
            build_code_retriever(repo_dir, persist_path=output_file,
                                 similarity_top_k=similarity_top_k)
            print(f'[{rank}] Processed {repo_name}')
        except Exception as e:
            # Best-effort: log and continue so one bad repo doesn't kill
            # the whole worker.
            print(f'[{rank}] Error processing {repo_name}: {e}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="czlll/SWE-bench_Lite")
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument('--num_processes', type=int, default=30)
    parser.add_argument('--download_repo', action='store_true',
                        help='Whether to download the codebase to `repo_path` before indexing.')
    parser.add_argument('--repo_path', type=str, default='playground/build_graph',
                        help='The directory where you plan to pull or have already pulled the codebase.')
    parser.add_argument('--index_dir', type=str, default='index_data',
                        help='The base directory where the generated graph index will be saved.')
    parser.add_argument('--instance_id_path', type=str, default='',
                        help='Path to a file containing a list of selected instance IDs.')
    args = parser.parse_args()

    # Indexes are grouped per dataset under <index_dir>/<dataset>/BM25_index/.
    dataset_name = args.dataset.split('/')[-1]
    args.index_dir = f'{args.index_dir}/{dataset_name}/BM25_index/'
    os.makedirs(args.index_dir, exist_ok=True)

    # Optional whitelist of instance IDs, stored as a JSON list on disk.
    id_filter = None
    if args.instance_id_path and osp.exists(args.instance_id_path):
        with open(args.instance_id_path, 'r') as f:
            id_filter = json.loads(f.read())

    if args.download_repo:
        # Workers will clone each repo themselves, so they need the full
        # benchmark record for every selected instance.
        bench_data = load_dataset(args.dataset, split=args.split)
        if id_filter is not None:
            repo_folders = id_filter
        else:
            repo_folders = [instance['instance_id'] for instance in bench_data]
        selected_instance_data = {
            instance['instance_id']: instance
            for instance in bench_data
            if id_filter is None or instance['instance_id'] in id_filter
        }
    else:
        # Codebases are already on disk; no benchmark metadata needed.
        if id_filter is not None:
            repo_folders = id_filter
        else:
            repo_folders = list_folders(args.repo_path)
        selected_instance_data = None
    os.makedirs(args.repo_path, exist_ok=True)

    # One shared queue feeds all workers so they self-balance the load.
    manager = mp.Manager()
    task_queue = manager.Queue()
    for name in repo_folders:
        task_queue.put(name)

    start_time = time.time()
    # mp.spawn prepends each worker's rank as the first argument to run().
    mp.spawn(
        run,
        nprocs=args.num_processes,
        args=(task_queue, args.repo_path, args.index_dir,
              args.download_repo, selected_instance_data),
        join=True
    )
    end_time = time.time()
    print(f'Total Execution time = {end_time - start_time:.3f}s')