Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sierralocal/hivdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class HIVdb():
webserver, to retrieve the rules-based prediction algorithm as ASI XML,
and convert this information into Python objects.
"""
def __init__(self, asi2=None, apobec=None, forceupdate=False):
def __init__(self, asi2=None, apobec=None, forceupdate=False, updater_outdir=None):
self.xml_filename = None
self.json_filename = None

if forceupdate:
import sierralocal.updater as updater
self.xml_filename = updater.update_HIVDB()
self.json_filename = updater.update_APOBEC()
self.xml_filename = updater.update_hivdb(updater_outdir)
self.json_filename = updater.update_apobec_mutation(updater_outdir)
else:
self.set_hivdb_xml(asi2)
self.set_apobec_json(apobec)
Expand Down
47 changes: 42 additions & 5 deletions sierralocal/jsonwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class JSONWriter():
def __init__(self, algorithm):
def __init__(self, algorithm, apobec_csv, unusual_csv, sdrms_csv, mutation_csv):
# possible alternative drug abbrvs
self.names = {'3TC': 'LMV'}

Expand Down Expand Up @@ -39,7 +39,16 @@ def __init__(self, algorithm):
self.rt_comments = dict(csv.reader(rt_file, delimiter='\t'))

# make dictionary for isUnusual
dest = str(Path(os.path.dirname(__file__)) / 'data' / 'rx-all_subtype-all.csv')
if unusual_csv is None:
dest = str(Path(os.path.dirname(__file__)) / 'data' / 'rx-all_subtype-all.csv')
else:
if os.path.isfile(unusual_csv): # Ensure is a file
dest = unusual_csv
else:
raise FileNotFoundError(
"Path to CSV file to determine if is unusual cannot be found at user specified "
"path {}".format(unusual_csv))
print("Using unusual file: "+dest)
with open(dest, 'r', encoding='utf-8-sig') as is_unusual_file:
is_unusual_file = csv.DictReader(is_unusual_file)
self.is_unusual_dic = {}
Expand All @@ -54,7 +63,16 @@ def __init__(self, algorithm):
self.is_unusual_dic[gene].update({pos: {}})
self.is_unusual_dic[gene][pos].update({aa: unusual})

dest = str(Path(os.path.dirname(__file__)) / 'data' / 'sdrms_hiv1.csv')
if sdrms_csv is None:
dest = str(Path(os.path.dirname(__file__)) / 'data' / 'sdrms_hiv1.csv')
else:
if os.path.isfile(sdrms_csv): # Ensure is a file
dest = sdrms_csv
else:
raise FileNotFoundError(
"Path to CSV file to determine SDRM mutations cannot be found at user specified "
"path {}".format(sdrms_csv))
print("Using SDRM mutations file: "+dest)
with open(dest, 'r', encoding='utf-8-sig') as sdrm_files:
sdrm_files = csv.DictReader(sdrm_files)
self.sdrm_dic = {}
Expand Down Expand Up @@ -86,7 +104,17 @@ def __init__(self, algorithm):
self.apobec_drm_dic[gene][position] += aa

# make dictionary for primary type
dest = str(Path(os.path.dirname(__file__)) / 'data' / 'mutation-type-pairs_hiv1.csv')
if mutation_csv is None:
dest = str(Path(os.path.dirname(__file__)) / 'data' / 'mutation-type-pairs_hiv1.csv')
else:
if os.path.isfile(mutation_csv): # Ensure is a file
dest = mutation_csv
else:
raise FileNotFoundError(
"Path to CSV file to determine mutation type cannot be found at user specified "
"path {}".format(mutation_csv))

print("Using mutation type file: "+dest)
with open(dest, 'r', encoding='utf-8-sig') as mut_type_pairs1_files:
mut_type_pairs1_files = csv.DictReader(mut_type_pairs1_files)
self.primary_type_dic = {}
Expand All @@ -102,7 +130,16 @@ def __init__(self, algorithm):
self.primary_type_dic[gene][pos].update({aa: mut})

# make dictionary for apobec mutations
dest = str(Path(os.path.dirname(__file__)) / 'data' / 'apobecs.csv')
if apobec_csv is None:
dest = str(Path(os.path.dirname(__file__)) / 'data' / 'apobecs.csv')
else:
if os.path.isfile(apobec_csv): # Ensure is a file
dest = apobec_csv
else:
raise FileNotFoundError(
"Path to CSV file with APOBEC cannot be found at user specified "
"path {}".format(apobec_csv))
print("Using APOBEC file: "+dest)
with open(dest, 'r', encoding='utf-8-sig') as apobec_mutations:
apobec_mutations = csv.DictReader(apobec_mutations)
self.apobec_mutations_dic = {}
Expand Down
37 changes: 32 additions & 5 deletions sierralocal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import time
import argparse
import json
from pathlib import Path

from sierralocal import score_alg
from sierralocal.hivdb import HIVdb
from sierralocal.jsonwriter import JSONWriter
from sierralocal.nucaminohook import NucAminoAligner


def score(filename, xml_path=None, tsv_path=None, forceupdate=False, do_subtype=False, program='post'): # pragma: no cover
"""
Functionality as a Python module. Can import this function from sierralocal.
Expand Down Expand Up @@ -123,7 +123,8 @@ def scorefile(input_file, algorithm, do_subtype=False, program='post'):
file_genes, sequence_lengths, file_trims, subtypes, na_sequence, ambiguous, gene_order

def sierralocal(fasta, outfile, xml=None, json=None, cleanup=False, forceupdate=False,
program='post', do_subtype=False): # pragma: no cover
apobec_csv=None, unusual_csv=None, sdrms_csv=None, mutation_csv=None,
updater_outdir=None, program='post', do_subtype=False): # pragma: no cover
"""
Contains all initializing and processing calls.

Expand All @@ -134,13 +135,17 @@ def sierralocal(fasta, outfile, xml=None, json=None, cleanup=False, forceupdate=
@param json: <optional> str, path to local copy of HIVdb algorithm APOBEC DRM file
@param cleanup: <optional> bool, to delete alignment file
@param forceupdate: <optional> bool, forces sierralocal to update its local copy of the HIVdb algorithm
@param apobec_csv: str <optional>, Path to CSV APOBEC csv file (default: apobecs.csv)
@param unusual_csv: str <optional>, Path to CSV file to determine if is unusual (default: rx-all_subtype-all.csv)
@param sdrms_csv: str <optional>, Path to CSV file to determine SDRM mutations (default: sdrms_hiv1.csv)
@param mutation_csv: str <optional>, Path to CSV file to determine mutation type (default: mutation-type-pairs_hiv1.csv)
@return: tuple, a tuple of (number of records processed, time elapsed initializing algorithm)
"""

# initialize algorithm and jsonwriter
time0 = time.time()
algorithm = HIVdb(asi2=xml, apobec=json, forceupdate=forceupdate)
writer = JSONWriter(algorithm)
algorithm = HIVdb(asi2=xml, apobec=json, forceupdate=forceupdate, updater_outdir=updater_outdir)
writer = JSONWriter(algorithm, apobec_csv, unusual_csv, sdrms_csv, mutation_csv)
time_elapsed = time.time() - time0

# accommodate single file path argument
Expand Down Expand Up @@ -197,16 +202,36 @@ def parse_args(): # pragma: no cover
help='Forces update of HIVdb algorithm. Requires network connection.')
parser.add_argument('-alignment', default='post', choices=['post', 'nuc'],
help='Alignment program to use, "post" for post align and "nuc" for nucamino')
parser.add_argument('-apobec_csv', default=None,
help='<optional> Path to CSV APOBEC csv file (default: apobecs.csv)')
parser.add_argument('-unusual_csv', default=None,
help='<optional> Path to CSV file to determine if is unusual (default: rx-all_subtype-all.csv)')
parser.add_argument('-sdrms_csv', default=None,
help='<optional> Path to CSV file to determine SDRM mutations (default: sdrms_hiv1.csv)')
parser.add_argument('-mutation_csv', default=None,
help='<optional> Path to CSV file to determine mutation type (default: mutation-type-pairs_hiv1.csv)')
parser.add_argument('-updater_outdir', default=None,
help='<optional> Path to folder to store updated files from updater (default: sierralocal/data folder))')

args = parser.parse_args()
return args


def main(): # pragma: no cover
"""
Main function called from CLI.
"""
args = parse_args()

mod_path = Path(os.path.dirname(__file__))

if args.updater_outdir:
target_dir = args.updater_outdir
else:
target_dir = os.path.join(mod_path, "data")

# Create directory if it doesn't exist
os.makedirs(target_dir, exist_ok=True)

# check that FASTA files in list all exist
for file in args.fasta:
if not os.path.exists(file):
Expand All @@ -216,6 +241,8 @@ def main(): # pragma: no cover
time_start = time.time()
count, time_elapsed = sierralocal(args.fasta, args.outfile, xml=args.xml,
json=args.json, cleanup=args.cleanup, forceupdate=args.forceupdate,
apobec_csv=args.apobec_csv, unusual_csv=args.unusual_csv,
sdrms_csv=args.sdrms_csv, mutation_csv=args.mutation_csv, updater_outdir=target_dir,
program=args.alignment)
time_diff = time.time() - time_start

Expand Down
Loading