diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..bfd11a5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,28 @@
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
\ No newline at end of file
diff --git a/README.md b/README.md
index b7eaf50..5a4b069 100644
--- a/README.md
+++ b/README.md
@@ -34,24 +34,36 @@
 and [video](https://www.youtube.com/watch?v=EVlLqig3qEI).
 
 * Clone the repository:
 ```git clone git@github.com:PopicLab/cue.git```
 
-#### Setup a Python virtual environment (recommended)
+* Move into the repository directory:
+```cd cue```
+
+#### Option 1: Set up a Python virtual environment (recommended)
 
 * Create the virtual environment (in the env directory):
-```$> python3.7 -m venv env```
+```python3.7 -m venv env```
 
 * Activate the environment:
-```$> source env/bin/activate```
+```source env/bin/activate```
 
 * Install all the required packages in the virtual environment (this should take a few minutes):
-```$> pip --no-cache-dir install -r install/requirements.txt```
-Packages can also be installed individually using the versions
-provided in the ```install/requirements.txt``` file; for example:
-```$> pip install numpy==1.18.5```
+```pip --no-cache-dir install .```
+
+To deactivate the environment: ```deactivate```
+
+#### Option 2: Set up a Conda environment
+
+Alternatively, you can create a Conda environment. Install [Miniconda](https://docs.conda.io/en/latest/miniconda.html) if you don't have Conda.
+
+* Create the Conda environment:
+```conda create -n cue python=3.7```
 
-* Set the ```PYTHONPATH``` as follows: ```export PYTHONPATH=${PYTHONPATH}:/path/to/cue```
+* Activate the environment:
+```conda activate cue```
 
-To deactivate the environment: ```$> deactivate```
+* Install cue and all required dependencies:
+```pip install .```
+To deactivate the environment: ```conda deactivate```
 
 #### Download the latest pre-trained Cue model
diff --git a/engine/call.py b/engine/call.py
index 73870e8..9a9afc2 100644
--- a/engine/call.py
+++ b/engine/call.py
@@ -44,140 +44,149 @@ import warnings
 warnings.filterwarnings("ignore")
 
-print("*********************************")
-print("* cue (%s): discovery mode *" % engine.__version__)
-print("*********************************")
-
-
-# ------ CLI ------
-parser = argparse.ArgumentParser(description='SV calling functionality')
-parser.add_argument('--data_config', help='Data config')
-parser.add_argument('--model_config', help='Trained model config')
-parser.add_argument('--refine_config', help='Trained refine model config', default=None)
-parser.add_argument('--skip_inference', action='store_true', help='Do not re-run image-based inference', default=False)
-args = parser.parse_args()
-
-# load the configs
-config = config_utils.load_config(args.model_config, config_type=config_utils.CONFIG_TYPE.TEST)
-data_config = config_utils.load_config(args.data_config, config_type=config_utils.CONFIG_TYPE.DATA)
-refine_config = None
-if args.refine_config is not None:
-    refine_config = config_utils.load_config(args.refine_config)
-given_ground_truth = data_config.bed is not None # (benchmarking mode)
-
-
-def call(device, chr_names, uid):
-    # runs SV calling on the specified device for the specified list of chromosomes
-    # load the pre-trained model on the specified device
-    model = models.MultiSVHG(config)
-    model.load_state_dict(torch.load(config.model_path, device))
-    model.to(device)
-    
logging.root.setLevel(logging.getLevelName(config.logging_level)) - logging.info("Loaded model: %s on %s" % (config.model_path, str(device))) - - # process each chromosome, loaded as a separate dataset - for chr_name in chr_names: - predictions_dir = "%s/predictions/%s.%s/" % (config.report_dir, uid, chr_name) - Path(predictions_dir).mkdir(parents=True, exist_ok=True) - aln_index = AlnIndex.generate_or_load(chr_name, data_config) - dataset = SVStreamingDataset(data_config, interval_size=interval_size, step_size=step_size, store=False, - include_chrs=[chr_name], allow_empty=True, aln_index=aln_index) - data_loader = DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=False, - collate_fn=datasets.collate_fn) - logging.info("Generating SV predictions for %s" % chr_name) - predictions = core.evaluate(model, data_loader, config, device, output_dir=predictions_dir, - collect_data_metrics=True, given_ground_truth=given_ground_truth) - torch.save(predictions, "%s/predictions.pkl" % predictions_dir) - return True - - -# ------ Image-based discovery ------ -n_procs = len(config.devices) -chr_name_chunks, chr_names = seq_utils.partition_chrs(data_config.chr_names, data_config.fai, n_procs) -logging.info("Running on %d CPUs/GPUs" % n_procs) -logging.info("Chromosome lists processed by each process: " + str(chr_name_chunks)) -outputs_per_scan = [] -for interval_size in [data_config.interval_size]: - for step_size in [data_config.step_size]: - scan_id = len(outputs_per_scan) - if not args.skip_inference: - _ = Parallel(n_jobs=n_procs)( - delayed(call)(config.devices[i], chr_name_chunks[i], scan_id) for i in range(n_procs)) - outputs = [] - for chr_name in chr_names: - predictions_dir = "%s/predictions/0.%s/" % (config.report_dir, chr_name) - logging.debug("Loading: ", predictions_dir) - predictions_per_chr = torch.load("%s/predictions.pkl" % predictions_dir) - outputs.extend(predictions_per_chr) - outputs_per_scan.append(outputs) - -# ------ Genome-based post-processing ------ -chr_index = io.load_faidx(data_config.fai) -candidates_per_scan = [] -for outputs in outputs_per_scan: - candidates = [] - filtered_candidates = [] - for output in outputs: - svs, filtered_svs = utils.img_to_svs(output, data_config, chr_index) - candidates.extend(svs) - filtered_candidates.extend(filtered_svs) - candidates = sv_filters.filter_svs(candidates, data_config.blacklist_bed, filtered_candidates) - candidates_per_scan.append(candidates) - -sv_calls = candidates_per_scan[0] -for i in range(1, len(candidates_per_scan)): - sv_calls = sv_filters.merge_sv_candidates(sv_calls, candidates_per_scan[i]) - -# output candidate SVs (pre-refinement) -candidate_out_bed_file = "%s/candidate_svs.bed" % config.report_dir -io.write_bed(candidate_out_bed_file, sv_calls) -chr2calls = defaultdict(list) -for sv in sv_calls: - chr2calls[sv.intervalA.chr_name].append(sv) -for chr_name in chr_names: - io.write_bed("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name), chr2calls[chr_name]) - -# ------ NN-aided breakpoint refinement ------ -post_process_refined = False -if refine_config is not None and refine_config.pretrained_model is not None: - def refine(device, chr_names): - refinet = models.CueModelConfig(refine_config).get_model() - refinet.load_state_dict(torch.load(refine_config.pretrained_model, refine_config.device)) - refinet.to(device) - refinet.eval() - refinery = SVKeypointRefinery(refinet, device, refine_config.padding, refine_config.image_dim) +# Don't use interactive backend for figure generation. 
+# See https://stackoverflow.com/questions/19518352/tkinter-tclerror-couldnt-connect-to-display-localhost18-0 +import matplotlib +matplotlib.use('Agg') + +def main(): + print("*********************************") + print("* cue (%s): discovery mode *" % engine.__version__) + print("*********************************") + + + # ------ CLI ------ + parser = argparse.ArgumentParser(description='SV calling functionality') + parser.add_argument('--data_config', help='Data config') + parser.add_argument('--model_config', help='Trained model config') + parser.add_argument('--refine_config', help='Trained refine model config', default=None) + parser.add_argument('--skip_inference', action='store_true', help='Do not re-run image-based inference', default=False) + args = parser.parse_args() + + # load the configs + config = config_utils.load_config(args.model_config, config_type=config_utils.CONFIG_TYPE.TEST) + data_config = config_utils.load_config(args.data_config, config_type=config_utils.CONFIG_TYPE.DATA) + refine_config = None + if args.refine_config is not None: + refine_config = config_utils.load_config(args.refine_config) + given_ground_truth = data_config.bed is not None # (benchmarking mode) + + + def call(device, chr_names, uid): + # runs SV calling on the specified device for the specified list of chromosomes + # load the pre-trained model on the specified device + model = models.MultiSVHG(config) + model.load_state_dict(torch.load(config.model_path, device)) + model.to(device) + logging.root.setLevel(logging.getLevelName(config.logging_level)) + logging.info("Loaded model: %s on %s" % (config.model_path, str(device))) + + # process each chromosome, loaded as a separate dataset for chr_name in chr_names: - refinery.bam_index = AlnIndex.generate_or_load(chr_name, refine_config) - refinery.image_generator = SVStreamingDataset(refine_config, interval_size=None, store=False, - allow_empty=True, aln_index=refinery.bam_index) - chr_calls = io.bed2sv_calls("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name)) - for sv_call in chr_calls: - refinery.refine_call(sv_call) - chr_out_bed_file = "%s/refined_svs.%s.bed" % (config.report_dir, chr_name) - io.write_bed(chr_out_bed_file, chr_calls) - Parallel(n_jobs=n_procs)(delayed(refine)(chr_name_chunks[i]) for i in range(n_procs)) - post_process_refined = True -elif not data_config.refine_disable: # ------ Genome-based breakpoint refinement ------ - def refine(chr_names): - for chr_name in chr_names: - chr_calls = io.bed2sv_calls("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name)) - os.remove("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name)) - chr_calls = seq.refinery.refine_svs(chr_calls, data_config, chr_index) - io.write_bed("%s/refined_svs.%s.bed" % (config.report_dir, chr_name), chr_calls) - Parallel(n_jobs=n_procs)(delayed(refine)(chr_name_chunks[i]) for i in range(n_procs)) - post_process_refined = True - - -# output candidate SVs (post-refinement) -if post_process_refined: - sv_calls_refined = [] + predictions_dir = "%s/predictions/%s.%s/" % (config.report_dir, uid, chr_name) + Path(predictions_dir).mkdir(parents=True, exist_ok=True) + aln_index = AlnIndex.generate_or_load(chr_name, data_config) + dataset = SVStreamingDataset(data_config, interval_size=interval_size, step_size=step_size, store=False, + include_chrs=[chr_name], allow_empty=True, aln_index=aln_index) + data_loader = DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=False, + collate_fn=datasets.collate_fn) + logging.info("Generating SV predictions 
for %s" % chr_name) + predictions = core.evaluate(model, data_loader, config, device, output_dir=predictions_dir, + collect_data_metrics=True, given_ground_truth=given_ground_truth) + torch.save(predictions, "%s/predictions.pkl" % predictions_dir) + return True + + + # ------ Image-based discovery ------ + n_procs = len(config.devices) + chr_name_chunks, chr_names = seq_utils.partition_chrs(data_config.chr_names, data_config.fai, n_procs) + logging.info("Running on %d CPUs/GPUs" % n_procs) + logging.info("Chromosome lists processed by each process: " + str(chr_name_chunks)) + outputs_per_scan = [] + for interval_size in [data_config.interval_size]: + for step_size in [data_config.step_size]: + scan_id = len(outputs_per_scan) + if not args.skip_inference: + _ = Parallel(n_jobs=n_procs)( + delayed(call)(config.devices[i], chr_name_chunks[i], scan_id) for i in range(n_procs)) + outputs = [] + for chr_name in chr_names: + predictions_dir = "%s/predictions/0.%s/" % (config.report_dir, chr_name) + logging.debug("Loading: ", predictions_dir) + predictions_per_chr = torch.load("%s/predictions.pkl" % predictions_dir) + outputs.extend(predictions_per_chr) + outputs_per_scan.append(outputs) + + # ------ Genome-based post-processing ------ + chr_index = io.load_faidx(data_config.fai) + candidates_per_scan = [] + for outputs in outputs_per_scan: + candidates = [] + filtered_candidates = [] + for output in outputs: + svs, filtered_svs = utils.img_to_svs(output, data_config, chr_index) + candidates.extend(svs) + filtered_candidates.extend(filtered_svs) + candidates = sv_filters.filter_svs(candidates, data_config.blacklist_bed, filtered_candidates) + candidates_per_scan.append(candidates) + + sv_calls = candidates_per_scan[0] + for i in range(1, len(candidates_per_scan)): + sv_calls = sv_filters.merge_sv_candidates(sv_calls, candidates_per_scan[i]) + + # output candidate SVs (pre-refinement) + candidate_out_bed_file = "%s/candidate_svs.bed" % config.report_dir + io.write_bed(candidate_out_bed_file, sv_calls) + chr2calls = defaultdict(list) + for sv in sv_calls: + chr2calls[sv.intervalA.chr_name].append(sv) for chr_name in chr_names: - sv_calls_refined.extend(io.bed2sv_calls("%s/refined_svs.%s.bed" % (config.report_dir, chr_name))) - os.remove("%s/refined_svs.%s.bed" % (config.report_dir, chr_name)) - candidate_out_bed_file = "%s/refined_svs.bed" % config.report_dir - io.write_bed(candidate_out_bed_file, sv_calls_refined) - -# ------ IO ------ -# write SV calls to file -io.bed2vcf(candidate_out_bed_file, "%s/svs.vcf" % config.report_dir, data_config.fai, - min_score=data_config.min_qual_score, min_len=data_config.min_sv_len) + io.write_bed("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name), chr2calls[chr_name]) + + # ------ NN-aided breakpoint refinement ------ + post_process_refined = False + if refine_config is not None and refine_config.pretrained_model is not None: + def refine(device, chr_names): + refinet = models.CueModelConfig(refine_config).get_model() + refinet.load_state_dict(torch.load(refine_config.pretrained_model, refine_config.device)) + refinet.to(device) + refinet.eval() + refinery = SVKeypointRefinery(refinet, device, refine_config.padding, refine_config.image_dim) + for chr_name in chr_names: + refinery.bam_index = AlnIndex.generate_or_load(chr_name, refine_config) + refinery.image_generator = SVStreamingDataset(refine_config, interval_size=None, store=False, + allow_empty=True, aln_index=refinery.bam_index) + chr_calls = io.bed2sv_calls("%s/candidate_svs.%s.bed" % 
(config.report_dir, chr_name)) + for sv_call in chr_calls: + refinery.refine_call(sv_call) + chr_out_bed_file = "%s/refined_svs.%s.bed" % (config.report_dir, chr_name) + io.write_bed(chr_out_bed_file, chr_calls) + Parallel(n_jobs=n_procs)(delayed(refine)(chr_name_chunks[i]) for i in range(n_procs)) + post_process_refined = True + elif not data_config.refine_disable: # ------ Genome-based breakpoint refinement ------ + def refine(chr_names): + for chr_name in chr_names: + chr_calls = io.bed2sv_calls("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name)) + os.remove("%s/candidate_svs.%s.bed" % (config.report_dir, chr_name)) + chr_calls = seq.refinery.refine_svs(chr_calls, data_config, chr_index) + io.write_bed("%s/refined_svs.%s.bed" % (config.report_dir, chr_name), chr_calls) + Parallel(n_jobs=n_procs)(delayed(refine)(chr_name_chunks[i]) for i in range(n_procs)) + post_process_refined = True + + + # output candidate SVs (post-refinement) + if post_process_refined: + sv_calls_refined = [] + for chr_name in chr_names: + sv_calls_refined.extend(io.bed2sv_calls("%s/refined_svs.%s.bed" % (config.report_dir, chr_name))) + os.remove("%s/refined_svs.%s.bed" % (config.report_dir, chr_name)) + candidate_out_bed_file = "%s/refined_svs.bed" % config.report_dir + io.write_bed(candidate_out_bed_file, sv_calls_refined) + + # ------ IO ------ + # write SV calls to file + io.bed2vcf(candidate_out_bed_file, "%s/svs.vcf" % config.report_dir, data_config.fai, + min_score=data_config.min_qual_score, min_len=data_config.min_sv_len) + +if __name__ == "__main__": + main() diff --git a/engine/generate.py b/engine/generate.py index 74b5e22..c10c2f7 100644 --- a/engine/generate.py +++ b/engine/generate.py @@ -33,46 +33,50 @@ import os import torch -print("*********************************") -print("* cue (%s): image-gen mode *" % engine.__version__) -print("*********************************") -# ------ CLI ------ -parser = argparse.ArgumentParser(description='Generate an SV image dataset') -parser.add_argument('--config', help='Dataset config') -args = parser.parse_args() -# ----------------- +def main(): + print("*********************************") + print("* cue (%s): image-gen mode *" % engine.__version__) + print("*********************************") + # ------ CLI ------ + parser = argparse.ArgumentParser(description='Generate an SV image dataset') + parser.add_argument('--config', help='Dataset config') + args = parser.parse_args() + # ----------------- -def generate(chr_names): - logging.root.setLevel(logging.INFO) - # generates images/annotations for the specified list of chromosomes - for chr_name in chr_names: - aln_index = AlnIndex.generate_or_load(chr_name, config) - dataset = datasets.SVStreamingDataset(config, config.interval_size, config.step_size, - allow_empty=config.allow_empty, - store=config.store_img, include_chrs=[chr_name], aln_index=aln_index, - remove_annotation=config.empty_annotation) - chr_stats = DatasetStats("%s/%s" % (config.info_dir, chr_name), classes=config.classes) - for _, target in dataset: - chr_stats.update(target) - chr_stats.report() - return True + def generate(chr_names): + logging.root.setLevel(logging.INFO) + # generates images/annotations for the specified list of chromosomes + for chr_name in chr_names: + aln_index = AlnIndex.generate_or_load(chr_name, config) + dataset = datasets.SVStreamingDataset(config, config.interval_size, config.step_size, + allow_empty=config.allow_empty, + store=config.store_img, include_chrs=[chr_name], aln_index=aln_index, + 
remove_annotation=config.empty_annotation) + chr_stats = DatasetStats("%s/%s" % (config.info_dir, chr_name), classes=config.classes) + for _, target in dataset: + chr_stats.update(target) + chr_stats.report() + return True -config = config_utils.load_config(args.config, config_type=config_utils.CONFIG_TYPE.DATA) -chr_name_chunks, _ = utils.partition_chrs(config.chr_names, config.fai, config.n_cpus) -logging.info("Running on %d CPUs" % config.n_cpus) -logging.info("Chromosome lists processed by each process: " + str(chr_name_chunks)) -_ = Parallel(n_jobs=config.n_cpus)( - delayed(generate)(chr_name_chunks[i]) for i in range(config.n_cpus)) -# generate stats for the entire dataset -stats = DatasetStats("%s/%s" % (config.info_dir, "full"), classes=config.classes) -targets = list(os.listdir(config.annotation_dir)) -for target_fname in targets: - target_path = os.path.join(config.annotation_dir, target_fname) - target = torch.load(target_path) - stats.update(target) -stats.report() + config = config_utils.load_config(args.config, config_type=config_utils.CONFIG_TYPE.DATA) + chr_name_chunks, _ = utils.partition_chrs(config.chr_names, config.fai, config.n_cpus) + logging.info("Running on %d CPUs" % config.n_cpus) + logging.info("Chromosome lists processed by each process: " + str(chr_name_chunks)) + _ = Parallel(n_jobs=config.n_cpus)( + delayed(generate)(chr_name_chunks[i]) for i in range(config.n_cpus)) + # generate stats for the entire dataset + stats = DatasetStats("%s/%s" % (config.info_dir, "full"), classes=config.classes) + targets = list(os.listdir(config.annotation_dir)) + for target_fname in targets: + target_path = os.path.join(config.annotation_dir, target_fname) + target = torch.load(target_path) + stats.update(target) + stats.report() + +if __name__ == "__main__": + main() diff --git a/engine/train.py b/engine/train.py index 6baff90..3500a14 100644 --- a/engine/train.py +++ b/engine/train.py @@ -35,78 +35,82 @@ torch.manual_seed(0) +def main(): + print("*********************************") + print("* cue (%s): training mode *" % engine.__version__) + print("*********************************") -print("*********************************") -print("* cue (%s): training mode *" % engine.__version__) -print("*********************************") + # ------ CLI ------ + parser = argparse.ArgumentParser(description='Cue model training') + parser.add_argument('--config', help='Training config') + parser.add_argument('--data_config', help='(Optional) Dataset config for streaming', default=None) + args = parser.parse_args() + # ----------------- -# ------ CLI ------ -parser = argparse.ArgumentParser(description='Cue model training') -parser.add_argument('--config', help='Training config') -parser.add_argument('--data_config', help='(Optional) Dataset config for streaming', default=None) -args = parser.parse_args() -# ----------------- + # ------ Initialization ------ + # load the model configs / setup the experiment + config = config_utils.load_config(args.config) + PHASES = Enum('PHASES', 'TRAIN VALIDATE') -# ------ Initialization ------ -# load the model configs / setup the experiment -config = config_utils.load_config(args.config) -PHASES = Enum('PHASES', 'TRAIN VALIDATE') + # ---------Training dataset-------- + streaming = False + if args.data_config is None: + # static (pre-generated data) + input_datasets = [] + for dataset_id in range(len(config.dataset_dirs)): + input_datasets.append(datasets.SVStaticDataset(config.image_dirs[dataset_id], + config.annotation_dirs[dataset_id], + 
config.image_dim, config.signal_set, + constants.SV_SIGNAL_SET[config.signal_set_origin], + dataset_id)) + dataset = torch.utils.data.ConcatDataset(input_datasets) + validation_size = int(config.validation_ratio * len(dataset)) + train_size = len(dataset) - validation_size + train_data, validation_data = random_split(dataset, [train_size, validation_size]) + images, targets = next(iter(DataLoader(dataset=dataset, batch_size=min(len(dataset), 400), shuffle=True, + collate_fn=datasets.collate_fn))) + else: + # streaming (on-the-fly data generation) + # data divided into training / validation using chromosomes (since dataset length unknown) + # data_config.chr_names defines the split (exclude/include) + streaming = True + data_config = config_utils.load_config(args.data_config, config_type=config_utils.CONFIG_TYPE.DATA) + train_data = datasets.SVStreamingDataset(data_config, interval_size=data_config.interval_size[0], + step_size=data_config.step_size[0], + exclude_chrs=data_config.chr_names, + store=True, allow_empty=False) + validation_data = datasets.SVStreamingDataset(data_config, interval_size=data_config.interval_size[0], + step_size=data_config.step_size[0], + include_chrs=data_config.chr_names, + store=True, allow_empty=False) -# ---------Training dataset-------- -streaming = False -if args.data_config is None: - # static (pre-generated data) - input_datasets = [] - for dataset_id in range(len(config.dataset_dirs)): - input_datasets.append(datasets.SVStaticDataset(config.image_dirs[dataset_id], - config.annotation_dirs[dataset_id], - config.image_dim, config.signal_set, - constants.SV_SIGNAL_SET[config.signal_set_origin], - dataset_id)) - dataset = torch.utils.data.ConcatDataset(input_datasets) - validation_size = int(config.validation_ratio * len(dataset)) - train_size = len(dataset) - validation_size - train_data, validation_data = random_split(dataset, [train_size, validation_size]) - images, targets = next(iter(DataLoader(dataset=dataset, batch_size=min(len(dataset), 400), shuffle=True, - collate_fn=datasets.collate_fn))) -else: - # streaming (on-the-fly data generation) - # data divided into training / validation using chromosomes (since dataset length unknown) - # data_config.chr_names defines the split (exclude/include) - streaming = True - data_config = config_utils.load_config(args.data_config, config_type=config_utils.CONFIG_TYPE.DATA) - train_data = datasets.SVStreamingDataset(data_config, interval_size=data_config.interval_size[0], - step_size=data_config.step_size[0], - exclude_chrs=data_config.chr_names, - store=True, allow_empty=False) - validation_data = datasets.SVStreamingDataset(data_config, interval_size=data_config.interval_size[0], - step_size=data_config.step_size[0], - include_chrs=data_config.chr_names, - store=True, allow_empty=False) + # ---------Data loaders-------- + data_loaders = {PHASES.TRAIN: DataLoader(dataset=train_data, batch_size=config.batch_size, shuffle=True, + collate_fn=datasets.collate_fn), + PHASES.VALIDATE: DataLoader(dataset=validation_data, batch_size=config.batch_size, shuffle=False, + collate_fn=datasets.collate_fn)} + logging.info("Size of train set: %d; validation set: %d" % (len(data_loaders[PHASES.TRAIN]), + len(data_loaders[PHASES.VALIDATE]))) -# ---------Data loaders-------- -data_loaders = {PHASES.TRAIN: DataLoader(dataset=train_data, batch_size=config.batch_size, shuffle=True, - collate_fn=datasets.collate_fn), - PHASES.VALIDATE: DataLoader(dataset=validation_data, batch_size=config.batch_size, shuffle=False, - 
collate_fn=datasets.collate_fn)} -logging.info("Size of train set: %d; validation set: %d" % (len(data_loaders[PHASES.TRAIN]), - len(data_loaders[PHASES.VALIDATE]))) + # ---------Model-------- + model = models.CueModelConfig(config).get_model() + if config.pretrained_model is not None: + model.load_state_dict(torch.load(config.pretrained_model, config.device)) + optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.learning_rate_decay_interval, + gamma=config.learning_rate_decay_factor) + model.to(config.device) -# ---------Model-------- -model = models.CueModelConfig(config).get_model() -if config.pretrained_model is not None: - model.load_state_dict(torch.load(config.pretrained_model, config.device)) -optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) -lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.learning_rate_decay_interval, - gamma=config.learning_rate_decay_factor) -model.to(config.device) + # ------ Training ------ + for epoch in range(config.num_epochs): + core.train(model, optimizer, data_loaders[PHASES.TRAIN], config, epoch, collect_data_metrics=(epoch == 0)) + torch.save(model.state_dict(), "%s.epoch%d" % (config.model_path, epoch)) + core.evaluate(model, data_loaders[PHASES.VALIDATE], config, config.device, config.epoch_dirs[epoch], + collect_data_metrics=(epoch == 0), given_ground_truth=True, filters=False) + lr_scheduler.step() -# ------ Training ------ -for epoch in range(config.num_epochs): - core.train(model, optimizer, data_loaders[PHASES.TRAIN], config, epoch, collect_data_metrics=(epoch == 0)) - torch.save(model.state_dict(), "%s.epoch%d" % (config.model_path, epoch)) - core.evaluate(model, data_loaders[PHASES.VALIDATE], config, config.device, config.epoch_dirs[epoch], - collect_data_metrics=(epoch == 0), given_ground_truth=True, filters=False) - lr_scheduler.step() + torch.save(model.state_dict(), config.model_path) -torch.save(model.state_dict(), config.model_path) + +if __name__ == "__main__": + main() diff --git a/engine/view.py b/engine/view.py index 2b55c0c..c60000a 100644 --- a/engine/view.py +++ b/engine/view.py @@ -30,24 +30,28 @@ import warnings warnings.filterwarnings("ignore") +def main(): + # ------ CLI ------ + parser = argparse.ArgumentParser(description='View an SV callset') + parser.add_argument('--config', help='Dataset config') + args = parser.parse_args() + # ----------------- -# ------ CLI ------ -parser = argparse.ArgumentParser(description='View an SV callset') -parser.add_argument('--config', help='Dataset config') -args = parser.parse_args() -# ----------------- + def view(chr_names): + for chr_name in chr_names: + aln_index = AlnIndex.generate_or_load(chr_name, config) + logging.info("Generating SV images for %s" % chr_name) + dataset = datasets.SVBedScanner(config, config.interval_size, allow_empty=False, store=True, + include_chrs=[chr_name], aln_index=aln_index) + for _, target in dataset: + continue + return True -def view(chr_names): - for chr_name in chr_names: - aln_index = AlnIndex.generate_or_load(chr_name, config) - logging.info("Generating SV images for %s" % chr_name) - dataset = datasets.SVBedScanner(config, config.interval_size, allow_empty=False, store=True, - include_chrs=[chr_name], aln_index=aln_index) - for _, target in dataset: - continue - return True + config = config_utils.load_config(args.config, config_type=config_utils.CONFIG_TYPE.DATA) + chr_name_chunks, _ = 
utils.partition_chrs(config.chr_names, config.fai, config.n_cpus)
+    _ = Parallel(n_jobs=config.n_cpus)(delayed(view)(chr_name_chunks[i]) for i in range(config.n_cpus))
 
-config = config_utils.load_config(args.config, config_type=config_utils.CONFIG_TYPE.DATA)
-chr_name_chunks, _ = utils.partition_chrs(config.chr_names, config.fai, config.n_cpus)
-_ = Parallel(n_jobs=config.n_cpus)(delayed(view)(chr_name_chunks[i]) for i in range(config.n_cpus))
+
+if __name__ == "__main__":
+    main()
diff --git a/img/data_metrics.py b/img/data_metrics.py
index 039b6bc..8a6bbb3 100644
--- a/img/data_metrics.py
+++ b/img/data_metrics.py
@@ -2,6 +2,8 @@
 from collections import Counter
 import logging
 import seaborn as sns
+import matplotlib
+matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 import numpy as np
 import matplotlib.axes
diff --git a/img/plotting.py b/img/plotting.py
index 67a4578..76de9cd 100644
--- a/img/plotting.py
+++ b/img/plotting.py
@@ -20,6 +20,8 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import matplotlib
+matplotlib.use('Agg')
 from img import constants
 from img.constants import TargetType
@@ -35,7 +37,6 @@
 from scipy.ndimage.filters import maximum_filter, gaussian_filter
 from scipy.ndimage.morphology import generate_binary_structure
 
-
 def heatmap_np(overlap2D, img_size=1000, vmin=0, vmax=100, cvresize=False):
     overlap2D = overlap2D.transpose()
     overlap2D = np.flip(overlap2D, 0)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..ddb117b
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools==58.0.0"]
+build-backend = "setuptools.build_meta"
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..955eb29
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,47 @@
+[metadata]
+name = cue
+author = Victoria Popic
+url = https://github.com/PopicLab/cue
+license = MIT
+description = A deep learning framework for SV calling and genotyping
+long_description = file: README.md
+long_description_content_type = text/markdown
+version = attr: engine.__version__
+
+[options]
+python_requires = >=3.7
+install_requires =
+    bitarray ==1.6.3
+    cachetools ==4.1.0
+    intervaltree ==3.1.0
+    cython ==0.29.21
+    joblib ==0.16.0
+    matplotlib ==3.2.1
+    numpy ==1.18.5
+    opencv-python ==4.5.1.48
+    pandas ==1.0.5
+    pycocotools ==2.0.4
+    pyfaidx ==0.5.9.5
+    pysam ==0.16.0.1
+    pytabix ==0.1
+    python-dateutil ==2.8.1
+    pyyaml ==5.3.1
+    seaborn ==0.11.0
+    torch ==1.5.1
+    torchvision ==0.6.1
+packages = find:
+
+[options.packages.find]
+include =
+    engine
+    img
+    models
+    seq
+    utils
+
+[options.entry_points]
+console_scripts =
+    call.py = engine.call:main
+    generate.py = engine.generate:main
+    train.py = engine.train:main
+    view.py = engine.view:main
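
With this change the four engine scripts are exposed as `console_scripts` entry points that dispatch to the new `main()` functions, instead of being run via `PYTHONPATH`. A minimal post-install sanity check (a sketch: the `<...>` config paths are placeholders for your own config files, not files added by this patch; the flags match the argparse options defined above):

```
# inside the activated environment from the README instructions above
pip --no-cache-dir install .

# the entry points declared in setup.cfg are now on PATH
call.py --model_config <model_config> --data_config <data_config>
generate.py --config <data_config>
train.py --config <train_config>
view.py --config <data_config>
```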