28 changes: 28 additions & 0 deletions .gitignore
@@ -0,0 +1,28 @@
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
30 changes: 21 additions & 9 deletions README.md
@@ -34,24 +34,36 @@ and [video](https://www.youtube.com/watch?v=EVlLqig3qEI).
 * Clone the repository:
 ```git clone git@github.com:PopicLab/cue.git```
 
-#### Setup a Python virtual environment (recommended)
+* Move into the directory:
+```cd cue```
+
+#### Option 1: Set up a Python virtual environment (recommended)
 
 * Create the virtual environment (in the env directory):
-```$> python3.7 -m venv env```
+```python3.7 -m venv env```
 
 * Activate the environment:
-```$> source env/bin/activate```
+```source env/bin/activate```
 
 * Install all the required packages in the virtual environment (this should take a few minutes):
-```$> pip --no-cache-dir install -r install/requirements.txt```
-Packages can also be installed individually using the versions
-provided in the ```install/requirements.txt``` file; for example:
-```$> pip install numpy==1.18.5```
+```pip --no-cache-dir install .```
 
-* Set the ```PYTHONPATH``` as follows: ```export PYTHONPATH=${PYTHONPATH}:/path/to/cue```
+To deactivate the environment: ```deactivate```
 
-To deactivate the environment: ```$> deactivate```
+#### Option 2: Conda environment
+
+Alternatively, one can create a Conda environment. Install [miniconda] if you don't have Conda.
+
+* Set up the Conda environment:
+```conda create -n cue python=3.7```
+
+* Activate the environment:
+```conda activate cue```
+
+* Install cue and all required dependencies:
+```pip install .```
+
+To deactivate the environment: ```conda deactivate```

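* Optional sanity check: after either option, the discovery script can be invoked directly (the config paths below are placeholders; the flags are the ones registered by ```engine/call.py```):
```python engine/call.py --data_config <path/to/data_config> --model_config <path/to/model_config>```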
#### Download the latest pre-trained Cue model

279 changes: 144 additions & 135 deletions engine/call.py
@@ -44,140 +44,149 @@
 import warnings
 warnings.filterwarnings("ignore")
 
-print("*********************************")
-print("* cue (%s): discovery mode *" % engine.__version__)
-print("*********************************")
-
-
-# ------ CLI ------
-parser = argparse.ArgumentParser(description='SV calling functionality')
-parser.add_argument('--data_config', help='Data config')
-parser.add_argument('--model_config', help='Trained model config')
-parser.add_argument('--refine_config', help='Trained refine model config', default=None)
-parser.add_argument('--skip_inference', action='store_true', help='Do not re-run image-based inference', default=False)
-args = parser.parse_args()
-
-# load the configs
-config = config_utils.load_config(args.model_config, config_type=config_utils.CONFIG_TYPE.TEST)
-data_config = config_utils.load_config(args.data_config, config_type=config_utils.CONFIG_TYPE.DATA)
-refine_config = None
-if args.refine_config is not None:
-    refine_config = config_utils.load_config(args.refine_config)
-given_ground_truth = data_config.bed is not None  # (benchmarking mode)
-
-
-def call(device, chr_names, uid):
-    # runs SV calling on the specified device for the specified list of chromosomes
-    # load the pre-trained model on the specified device
-    model = models.MultiSVHG(config)
-    model.load_state_dict(torch.load(config.model_path, device))
-    model.to(device)
-    logging.root.setLevel(logging.getLevelName(config.logging_level))
-    logging.info("Loaded model: %s on %s" % (config.model_path, str(device)))
-
-    # process each chromosome, loaded as a separate dataset
-    for chr_name in chr_names:
-        predictions_dir = "%s/predictions/%s.%s/" % (config.report_dir, uid, chr_name)
-        Path(predictions_dir).mkdir(parents=True, exist_ok=True)
-        aln_index = AlnIndex.generate_or_load(chr_name, data_config)
-        dataset = SVStreamingDataset(data_config, interval_size=interval_size, step_size=step_size, store=False,
-                                     include_chrs=[chr_name], allow_empty=True, aln_index=aln_index)
-        data_loader = DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=False,
-                                 collate_fn=datasets.collate_fn)
-        logging.info("Generating SV predictions for %s" % chr_name)
-        predictions = core.evaluate(model, data_loader, config, device, output_dir=predictions_dir,
-                                    collect_data_metrics=True, given_ground_truth=given_ground_truth)
-        torch.save(predictions, "%s/predictions.pkl" % predictions_dir)
-    return True
-
-
-# ------ Image-based discovery ------
-n_procs = len(config.devices)
-chr_name_chunks, chr_names = seq_utils.partition_chrs(data_config.chr_names, data_config.fai, n_procs)
-logging.info("Running on %d CPUs/GPUs" % n_procs)
-logging.info("Chromosome lists processed by each process: " + str(chr_name_chunks))
-outputs_per_scan = []
-for interval_size in [data_config.interval_size]:
-    for step_size in [data_config.step_size]:
-        scan_id = len(outputs_per_scan)
-        if not args.skip_inference:
-            _ = Parallel(n_jobs=n_procs)(
-                delayed(call)(config.devices[i], chr_name_chunks[i], scan_id) for i in range(n_procs))
-        outputs = []
-        for chr_name in chr_names:
-            predictions_dir = "%s/predictions/0.%s/" % (config.report_dir, chr_name)
-            logging.debug("Loading: ", predictions_dir)
-            predictions_per_chr = torch.load("%s/predictions.pkl" % predictions_dir)
-            outputs.extend(predictions_per_chr)
-        outputs_per_scan.append(outputs)
-
-# ------ Genome-based post-processing ------
-chr_index = io.load_faidx(data_config.fai)
-candidates_per_scan = []
-for outputs in outputs_per_scan:
-    candidates = []
-    filtered_candidates = []
-    for output in outputs:
-        svs, filtered_svs = utils.img_to_svs(output, data_config, chr_index)
-        candidates.extend(svs)
-        filtered_candidates.extend(filtered_svs)
-    candidates = sv_filters.filter_svs(candidates, data_config.blacklist_bed, filtered_candidates)
-    candidates_per_scan.append(candidates)
-
-sv_calls = candidates_per_scan[0]
-for i in range(1, len(candidates_per_scan)):
-    sv_calls = sv_filters.merge_sv_candidates(sv_calls, candidates_per_scan[i])
-
-# output candidate SVs (pre-refinement)
-candidate_out_bed_file = "%s/candidate_svs.bed" % config.report_dir
-io.write_bed(candidate_out_bed_file, sv_calls)
-chr2calls = defaultdict(list)
-for sv in sv_calls:
-    chr2calls[sv.intervalA.chr_name].append(sv)
-for chr_name in chr_names:
-    io.write_bed("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name), chr2calls[chr_name])
-
-# ------ NN-aided breakpoint refinement ------
-post_process_refined = False
-if refine_config is not None and refine_config.pretrained_model is not None:
-    def refine(device, chr_names):
-        refinet = models.CueModelConfig(refine_config).get_model()
-        refinet.load_state_dict(torch.load(refine_config.pretrained_model, refine_config.device))
-        refinet.to(device)
-        refinet.eval()
-        refinery = SVKeypointRefinery(refinet, device, refine_config.padding, refine_config.image_dim)
-        for chr_name in chr_names:
-            refinery.bam_index = AlnIndex.generate_or_load(chr_name, refine_config)
-            refinery.image_generator = SVStreamingDataset(refine_config, interval_size=None, store=False,
-                                                          allow_empty=True, aln_index=refinery.bam_index)
-            chr_calls = io.bed2sv_calls("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name))
-            for sv_call in chr_calls:
-                refinery.refine_call(sv_call)
-            chr_out_bed_file = "%s/refined_svs.%s.bed" % (config.report_dir, chr_name)
-            io.write_bed(chr_out_bed_file, chr_calls)
-    Parallel(n_jobs=n_procs)(delayed(refine)(chr_name_chunks[i]) for i in range(n_procs))
-    post_process_refined = True
-elif not data_config.refine_disable:  # ------ Genome-based breakpoint refinement ------
-    def refine(chr_names):
-        for chr_name in chr_names:
-            chr_calls = io.bed2sv_calls("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name))
-            os.remove("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name))
-            chr_calls = seq.refinery.refine_svs(chr_calls, data_config, chr_index)
-            io.write_bed("%s/refined_svs.%s.bed" % (config.report_dir, chr_name), chr_calls)
-    Parallel(n_jobs=n_procs)(delayed(refine)(chr_name_chunks[i]) for i in range(n_procs))
-    post_process_refined = True
-
-
-# output candidate SVs (post-refinement)
-if post_process_refined:
-    sv_calls_refined = []
-    for chr_name in chr_names:
-        sv_calls_refined.extend(io.bed2sv_calls("%s/refined_svs.%s.bed" % (config.report_dir, chr_name)))
-        os.remove("%s/refined_svs.%s.bed" % (config.report_dir, chr_name))
-    candidate_out_bed_file = "%s/refined_svs.bed" % config.report_dir
-    io.write_bed(candidate_out_bed_file, sv_calls_refined)
-
-# ------ IO ------
-# write SV calls to file
-io.bed2vcf(candidate_out_bed_file, "%s/svs.vcf" % config.report_dir, data_config.fai,
-           min_score=data_config.min_qual_score, min_len=data_config.min_sv_len)
+# Don't use interactive backend for figure generation.
+# See https://stackoverflow.com/questions/19518352/tkinter-tclerror-couldnt-connect-to-display-localhost18-0
+import matplotlib
+matplotlib.use('Agg')
+
+
+def main():
+    print("*********************************")
+    print("* cue (%s): discovery mode *" % engine.__version__)
+    print("*********************************")
+
+    # ------ CLI ------
+    parser = argparse.ArgumentParser(description='SV calling functionality')
+    parser.add_argument('--data_config', help='Data config')
+    parser.add_argument('--model_config', help='Trained model config')
+    parser.add_argument('--refine_config', help='Trained refine model config', default=None)
+    parser.add_argument('--skip_inference', action='store_true', help='Do not re-run image-based inference', default=False)
+    args = parser.parse_args()
+
+    # load the configs
+    config = config_utils.load_config(args.model_config, config_type=config_utils.CONFIG_TYPE.TEST)
+    data_config = config_utils.load_config(args.data_config, config_type=config_utils.CONFIG_TYPE.DATA)
+    refine_config = None
+    if args.refine_config is not None:
+        refine_config = config_utils.load_config(args.refine_config)
+    given_ground_truth = data_config.bed is not None  # (benchmarking mode)
+
+    def call(device, chr_names, uid):
+        # runs SV calling on the specified device for the specified list of chromosomes
+        # load the pre-trained model on the specified device
+        model = models.MultiSVHG(config)
+        model.load_state_dict(torch.load(config.model_path, device))
+        model.to(device)
+        logging.root.setLevel(logging.getLevelName(config.logging_level))
+        logging.info("Loaded model: %s on %s" % (config.model_path, str(device)))
+
+        # process each chromosome, loaded as a separate dataset
+        for chr_name in chr_names:
+            predictions_dir = "%s/predictions/%s.%s/" % (config.report_dir, uid, chr_name)
+            Path(predictions_dir).mkdir(parents=True, exist_ok=True)
+            aln_index = AlnIndex.generate_or_load(chr_name, data_config)
+            dataset = SVStreamingDataset(data_config, interval_size=interval_size, step_size=step_size, store=False,
+                                         include_chrs=[chr_name], allow_empty=True, aln_index=aln_index)
+            data_loader = DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=False,
+                                     collate_fn=datasets.collate_fn)
+            logging.info("Generating SV predictions for %s" % chr_name)
+            predictions = core.evaluate(model, data_loader, config, device, output_dir=predictions_dir,
+                                        collect_data_metrics=True, given_ground_truth=given_ground_truth)
+            torch.save(predictions, "%s/predictions.pkl" % predictions_dir)
+        return True
+
+    # ------ Image-based discovery ------
+    n_procs = len(config.devices)
+    chr_name_chunks, chr_names = seq_utils.partition_chrs(data_config.chr_names, data_config.fai, n_procs)
+    logging.info("Running on %d CPUs/GPUs" % n_procs)
+    logging.info("Chromosome lists processed by each process: " + str(chr_name_chunks))
+    outputs_per_scan = []
+    for interval_size in [data_config.interval_size]:
+        for step_size in [data_config.step_size]:
+            scan_id = len(outputs_per_scan)
+            if not args.skip_inference:
+                _ = Parallel(n_jobs=n_procs)(
+                    delayed(call)(config.devices[i], chr_name_chunks[i], scan_id) for i in range(n_procs))
+            outputs = []
+            for chr_name in chr_names:
+                predictions_dir = "%s/predictions/0.%s/" % (config.report_dir, chr_name)
+                logging.debug("Loading: %s", predictions_dir)
+                predictions_per_chr = torch.load("%s/predictions.pkl" % predictions_dir)
+                outputs.extend(predictions_per_chr)
+            outputs_per_scan.append(outputs)
+
+    # ------ Genome-based post-processing ------
+    chr_index = io.load_faidx(data_config.fai)
+    candidates_per_scan = []
+    for outputs in outputs_per_scan:
+        candidates = []
+        filtered_candidates = []
+        for output in outputs:
+            svs, filtered_svs = utils.img_to_svs(output, data_config, chr_index)
+            candidates.extend(svs)
+            filtered_candidates.extend(filtered_svs)
+        candidates = sv_filters.filter_svs(candidates, data_config.blacklist_bed, filtered_candidates)
+        candidates_per_scan.append(candidates)
+
+    sv_calls = candidates_per_scan[0]
+    for i in range(1, len(candidates_per_scan)):
+        sv_calls = sv_filters.merge_sv_candidates(sv_calls, candidates_per_scan[i])
+
+    # output candidate SVs (pre-refinement)
+    candidate_out_bed_file = "%s/candidate_svs.bed" % config.report_dir
+    io.write_bed(candidate_out_bed_file, sv_calls)
+    chr2calls = defaultdict(list)
+    for sv in sv_calls:
+        chr2calls[sv.intervalA.chr_name].append(sv)
+    for chr_name in chr_names:
+        io.write_bed("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name), chr2calls[chr_name])
+
+    # ------ NN-aided breakpoint refinement ------
+    post_process_refined = False
+    if refine_config is not None and refine_config.pretrained_model is not None:
+        def refine(device, chr_names):
+            refinet = models.CueModelConfig(refine_config).get_model()
+            refinet.load_state_dict(torch.load(refine_config.pretrained_model, refine_config.device))
+            refinet.to(device)
+            refinet.eval()
+            refinery = SVKeypointRefinery(refinet, device, refine_config.padding, refine_config.image_dim)
+            for chr_name in chr_names:
+                refinery.bam_index = AlnIndex.generate_or_load(chr_name, refine_config)
+                refinery.image_generator = SVStreamingDataset(refine_config, interval_size=None, store=False,
+                                                              allow_empty=True, aln_index=refinery.bam_index)
+                chr_calls = io.bed2sv_calls("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name))
+                for sv_call in chr_calls:
+                    refinery.refine_call(sv_call)
+                chr_out_bed_file = "%s/refined_svs.%s.bed" % (config.report_dir, chr_name)
+                io.write_bed(chr_out_bed_file, chr_calls)
+        # refine() takes the device as its first argument, mirroring the call() dispatch above
+        Parallel(n_jobs=n_procs)(delayed(refine)(config.devices[i], chr_name_chunks[i]) for i in range(n_procs))
+        post_process_refined = True
+    elif not data_config.refine_disable:  # ------ Genome-based breakpoint refinement ------
+        def refine(chr_names):
+            for chr_name in chr_names:
+                chr_calls = io.bed2sv_calls("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name))
+                os.remove("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name))
+                chr_calls = seq.refinery.refine_svs(chr_calls, data_config, chr_index)
+                io.write_bed("%s/refined_svs.%s.bed" % (config.report_dir, chr_name), chr_calls)
+        Parallel(n_jobs=n_procs)(delayed(refine)(chr_name_chunks[i]) for i in range(n_procs))
+        post_process_refined = True
+
+    # output candidate SVs (post-refinement)
+    if post_process_refined:
+        sv_calls_refined = []
+        for chr_name in chr_names:
+            sv_calls_refined.extend(io.bed2sv_calls("%s/refined_svs.%s.bed" % (config.report_dir, chr_name)))
+            os.remove("%s/refined_svs.%s.bed" % (config.report_dir, chr_name))
+        candidate_out_bed_file = "%s/refined_svs.bed" % config.report_dir
+        io.write_bed(candidate_out_bed_file, sv_calls_refined)
+
+    # ------ IO ------
+    # write SV calls to file
+    io.bed2vcf(candidate_out_bed_file, "%s/svs.vcf" % config.report_dir, data_config.fai,
+               min_score=data_config.min_qual_score, min_len=data_config.min_sv_len)
+
+
+if __name__ == "__main__":
+    main()
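Note: moving the module-level pipeline into ```main()``` is what lets the ```pip install .``` workflow from the README expose the caller cleanly, since importing ```engine.call``` no longer triggers the whole pipeline, and packaging metadata can point a console command at the function. Below is a minimal sketch of such an entry point, assuming a setuptools-based ```setup.py```; this is illustrative only, not the repository's actual packaging configuration.

```python
# Hypothetical setup.py sketch -- not the repository's actual packaging metadata.
# Shows how main() in engine/call.py could be exposed as a console command.
from setuptools import setup, find_packages

setup(
    name="cue",
    packages=find_packages(),
    entry_points={
        "console_scripts": [
            # `pip install .` would then place a `cue-call` executable on PATH
            "cue-call=engine.call:main",
        ],
    },
)
```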