diff --git a/pysvtools/merge.py b/pysvtools/merge.py
index 7afca85..0101e51 100755
--- a/pysvtools/merge.py
+++ b/pysvtools/merge.py
@@ -22,11 +22,11 @@
     print("No PyVCF installation was found, please install with:\n\tpip install pyvcf")
     sys.exit(1)
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
 from pysvtools.models import Event
-from pysvtools.utils import extractTXmate, extractDPFromRecord, getSVType, getSVLEN, formatBedTrack, \
+from pysvtools.utils import extractTXmate, extractDPFromRecord, extractGTFromRecord, getSVType, getSVLEN, formatBedTrack, \
     formatVCFRecord, vcfHeader, build_exclusion
@@ -50,7 +50,7 @@ def __init__(self):
 
 
 # read all samples in memory
-def loadEventFromVCF(s, vcf_reader, edb, centerpointFlanking, transonly, svmethod=""):
+def loadEventFromVCF(sampleName, vcf_reader, edb, centerpointFlanking, transonly, svmethod=""):
     """
         Loading VCF records and transform them to `Event`
     """
@@ -82,6 +82,7 @@
                         end, sv_type="TRA",
                         cp_flank=centerpointFlanking,
                         dp=extractDPFromRecord(record),
+                        gt=extractGTFromRecord(record),
                         svmethod=svmethod)
                 svDB[t.virtualChr] = svDB.get(t.virtualChr, [])
                 svDB[t.virtualChr].append(t)
@@ -94,6 +95,7 @@
                     sv_type='TRA',
                     cp_flank=centerpointFlanking,
                     dp=extractDPFromRecord(record),
+                    gt=extractGTFromRecord(record),
                     svmethod=svmethod)
                 try:
                     svDB[t.virtualChr] = svDB.get(t.virtualChr, [])
@@ -120,6 +122,7 @@
                     sv_type=SVTYPE,
                     cp_flank=centerpointFlanking,
                     dp=extractDPFromRecord(record),
+                    gt=extractGTFromRecord(record),
                     svmethod=svmethod)
             except:
                 print("Unexpected error:", sys.exc_info()[0])
@@ -146,6 +149,7 @@
                     sv_type=SVTYPE,
                     cp_flank=centerpointFlanking,
                     dp=extractDPFromRecord(record),
+                    gt=extractGTFromRecord(record),
                     svmethod=svmethod)
             except:
                 print("Unexpected error:", sys.exc_info()[0])
@@ -162,9 +166,15 @@ def startMerge(vcf_files, exclusion_regions, output_file, centerpointFlanking, b
     regions_out_file = open(regions_out, "w")
     vcf_output_file = open(vcf_output, "w")
 
-    samplelist = vcf_files
+    # sampleList stores the paths to the input VCF files
+    sampleList = vcf_files
+    # sampleNames stores the real sample names (currently derived from the filenames only);
+    # only single-sample VCFs are supported for now
+    # TODO: extract the sample names from the sample columns in the VCF
+    sampleNames = [ os.path.basename(x).replace(".vcf", "").replace(".realign", "").replace(".baserecal", "").replace(".dedup", "") for x in vcf_files ]
+    # sampleDB stores the file read handles to the VCF files
     sampleDB = collections.OrderedDict()
+    # svDB stores the parsed contents of the SV calls
     svDB = collections.OrderedDict()
 
     commonhits = collections.OrderedDict()
@@ -174,15 +184,16 @@
     for exclusion_region in exclusion_regions:
         edb += build_exclusion(exclusion_region)
 
-    for s in samplelist:
-        logger.info('Reading SV-events from sample: {} '.format(s))
-        sampleDB[s] = vcf.Reader(open(s, 'r'))
+    for i, samplePath in enumerate(sampleList):
+        sampleName = sampleNames[i]
+        logger.info('Reading SV-events from sample: {} '.format(samplePath))
+        sampleDB[sampleName] = vcf.Reader(open(samplePath, 'r'))
 
         # extract SV caller from header
-        sv_caller = sampleDB[s].metadata.get('source', [os.path.basename(s).strip('.vcf')]).pop(0).split(' ').pop(0)
+        sv_caller = sampleDB[sampleName].metadata.get('source', [os.path.basename(samplePath).replace('.vcf', '')]).pop(0).split(' ').pop(0)
 
-        svDB[s] = loadEventFromVCF(s, sampleDB[s], edb, centerpointFlanking, transonly, sv_caller)
-        n_events = sum([len(calls) for chromlist, calls in svDB[s].items()])
+        svDB[sampleName] = loadEventFromVCF(sampleName, sampleDB[sampleName], edb, centerpointFlanking, transonly, sv_caller)
+        n_events = sum([len(calls) for chromlist, calls in svDB[sampleName].items()])
         logger.info('Loaded SV-events from sample: {} '.format(n_events))
 
     pairs_to_check = itertools.combinations(svDB.keys(), 2)
@@ -195,9 +206,9 @@
     chromosomes_to_check.sort()
 
     for (s1, s2) in pairs_to_check:
-        _s1 = os.path.basename(s1)
-        _s2 = os.path.basename(s2)
-        logger.debug('Pairwise compare: {} x {}'.format(_s1, _s2))
+        # _s1 = os.path.basename(s1)
+        # _s2 = os.path.basename(s2)
+        logger.debug('Pairwise compare: {} x {}'.format(s1, s2))
         for _chromosome in chromosomes_to_check:
             s1_calls_in_chromosome = svDB[s1].get(_chromosome, [])
             s2_calls_in_chromosome = svDB[s2].get(_chromosome, [])
@@ -206,35 +217,35 @@
 
             _match = 0
 
-            for i, t1 in enumerate(s1_calls_in_chromosome):
-                for j, t2 in enumerate(s2_calls_in_chromosome):
-                    if not (t1 and t2):
+            for i, eventA in enumerate(s1_calls_in_chromosome):
+                for j, eventB in enumerate(s2_calls_in_chromosome):
+                    if not (eventA and eventB):
                         continue
-                    if t1 == t2:
+                    if eventA == eventB:
                         # determine the object with the most DP and size
-                        if t1.size >= t2.size:
-                            _m = t1
+                        if eventA.size >= eventB.size:
+                            referenceEvent = eventA
                         else:
-                            _m = t2
+                            referenceEvent = eventB
 
                         # get the hashes from all hits, check whether one of them was already evaluated and is therefore in the table
-                        if t1.matched_in or t2.matched_in:
-                            m = t1.matched_in or t2.matched_in
+                        if eventA.matched_in or eventB.matched_in:
+                            matchedInObject = eventA.matched_in or eventB.matched_in
                         else:
-                            m = _m.hexdigest
+                            matchedInObject = referenceEvent.hexdigest
 
-                        t1.matched_in = m
-                        t2.matched_in = m
+                        eventA.matched_in = matchedInObject
+                        eventB.matched_in = matchedInObject
 
                         # TODO: add a getter method on Event instead of accessing the internal class variable
-                        virtualchrom = _m.virtualChr
+                        virtualChromosome = referenceEvent.virtualChr
 
                         # split out per chromosome storage
-                        commonhits[virtualchrom] = commonhits.get(virtualchrom, collections.OrderedDict())
-                        commonhits[virtualchrom][m] = commonhits[virtualchrom].get(m, collections.OrderedDict())
-                        commonhits[virtualchrom][m][s1] = t1
-                        commonhits[virtualchrom][m][s2] = t2
+                        commonhits[virtualChromosome] = commonhits.get(virtualChromosome, collections.OrderedDict())
+                        commonhits[virtualChromosome][matchedInObject] = commonhits[virtualChromosome].get(matchedInObject, collections.OrderedDict())
+                        commonhits[virtualChromosome][matchedInObject][s1] = eventA
+                        commonhits[virtualChromosome][matchedInObject][s2] = eventB
 
                         _match += 1
 
                         # when found, continue? Is it the best approach to break out of this long list comparison?
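
Note on the pairwise strategy above: itertools.combinations(svDB.keys(), 2) visits every unordered pair of samples exactly once, so the number of comparisons grows quadratically with the number of input VCFs — this is what the incremental approach in mergefaster.py below tries to avoid. A standalone illustration, not part of the patch:

    import itertools

    samples = ["tumor", "normal", "relapse"]  # hypothetical sample names
    print(list(itertools.combinations(samples, 2)))
    # [('tumor', 'normal'), ('tumor', 'relapse'), ('normal', 'relapse')]
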
@@ -242,10 +253,10 @@
             logger.debug("Common hits so far in {}: {} / {} vs {}".format(_chromosome, _match, s1_n_calls, s2_n_calls))
 
     # vcf header
-    print(vcfHeader(), file=vcf_output_file)
+    print(vcfHeader(sampleList=sampleNames), file=vcf_output_file)
 
     # tsv file
-    samplecols = "\t".join(map(lambda x: "{}\tsize".format(os.path.basename(x).strip(".vcf")), samplelist))
+    samplecols = "\t".join(map(lambda x: "{}\tsize".format(x), sampleNames))
     header_line = "\t".join(['ChrA', 'ChrApos', 'ChrB', 'ChrBpos', 'SVTYPE', 'DP', 'Size', samplecols])
 
     tsv_report_output = open(output_file, 'w')
@@ -255,34 +266,37 @@
     all_locations = []
 
     for virtualChr in natsorted(commonhits.keys()):
-        for s, items in sorted(commonhits[virtualChr].items(), key=lambda hit: hit[1].items()[0][1].chrApos):
-            if len(items):
+        for matchKey, hits in sorted(commonhits[virtualChr].items(), key=lambda hit: hit[1].items()[0][1].chrApos):
+            if len(hits):
                 # check which samples have the same event
                 locations_found = []
-                for sample in samplelist:
-                    if sample in items.keys():
-                        locations_found.append("{}\t{}".format(items[sample], items[sample].size))
+                for sample in sampleNames:
+                    if sample in hits.keys():
+                        locations_found.append("{}\t{}".format(hits[sample], hits[sample].size))
                         # track all locations found for later intersecting or complementing the set of found/not-found
-                        all_locations.append(items[sample])
+                        all_locations.append(hits[sample])
                     else:
                         locations_found.append("\t")
 
                 # get the key with the highest DP
-                sorted_by_dp = sorted(items.items(), key=lambda hit: hit[1].dp, reverse=True)
+                sorted_by_dp = sorted(hits.items(), key=lambda hit: hit[1].dp, reverse=True)
                 fKey = sorted_by_dp[0][0]
-                t = items[fKey]
+                outputRecord = hits[fKey]
 
-                print(formatVCFRecord(t), file=vcf_output_file)
+                logger.debug(outputRecord)
+                logger.debug(hits)
+
+                print(formatVCFRecord(outputRecord, hits, sampleNames), file=vcf_output_file)
 
                 tsv_report_output.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format(
-                    t.chrA,
-                    t.chrApos,
-                    t.chrB,
-                    t.chrBpos,
-                    t.sv_type,
-                    t.dp,
-                    t.size,
+                    outputRecord.chrA,
+                    outputRecord.chrApos,
+                    outputRecord.chrB,
+                    outputRecord.chrBpos,
+                    outputRecord.sv_type,
+                    outputRecord.dp,
+                    outputRecord.size,
                     "\t".join(locations_found)))
-                bed_structural_events.write(formatBedTrack(t))
+                bed_structural_events.write(formatBedTrack(outputRecord))
 
     for loc in all_locations:
         print(loc.bedRow, file=regions_out_file)
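
For context: the eventA == eventB test above delegates to Event.__eq__, which is not part of this patch. A minimal sketch of the centerpoint matching it is assumed to implement — the function name and the exact tolerance rule are assumptions, not the shipped code:

    def centerpoints_match(eventA, eventB):
        # Events can only match on the same virtual chromosome (e.g. "chr1chr1").
        if eventA.virtualChr != eventB.virtualChr:
            return False
        # Match when the centerpoints fall within the flanking tolerance,
        # rather than requiring reciprocal overlap of the two intervals.
        flank = max(eventA.centerpointFlanking, eventB.centerpointFlanking)
        return abs(eventA.centerpoint - eventB.centerpoint) <= flank
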
diff --git a/pysvtools/mergefaster.py b/pysvtools/mergefaster.py
index 1b91fd8..8d9e5e3 100644
--- a/pysvtools/mergefaster.py
+++ b/pysvtools/mergefaster.py
@@ -2,6 +2,8 @@
 
 from __future__ import print_function
 
+from pysvtools.merge import loadEventFromVCF
+
 __desc__ = """
 Merging procedure for Structural Variation events. Follows the idea of centerpoint matching
 to allow flexible match vs. reciprocal overlap.
@@ -45,7 +47,170 @@ class ReportExport(object):
     def __init__(self):
         pass
 
-
+def matchEvents(sampleName, s2_calls_in_chromosome, allhits, virtualChromosome):
+    s1_calls_in_chromosome = allhits.get(virtualChromosome, [])
+    s1_n_calls = len(s1_calls_in_chromosome)
+    s2_n_calls = len(s2_calls_in_chromosome)
+    _match = 0
+
+    for i, eventA in enumerate(s2_calls_in_chromosome):
+        # for each event of the newly loaded sample,
+        # find a match in the overall list: match it, or add it as a new hit
+
+        matched = False
+        for j, matchKey in enumerate(s1_calls_in_chromosome):
+            _eventB = [x for x in s1_calls_in_chromosome[matchKey].values() if x.hexdigest == matchKey]
+            if len(_eventB):
+                eventB = _eventB[0]
+            else:
+                print([x.hexdigest for x in s1_calls_in_chromosome[matchKey].values()])
+                print([x for x in s1_calls_in_chromosome[matchKey].values()])
+                print(matchKey)
+                print(_eventB)
+                # no stored event carries this key's digest; skip it rather than
+                # reusing a stale eventB from a previous iteration
+                continue
+
+            if not (eventA and eventB):
+                continue
+            if eventA == eventB:
+                matchedInObject = matchKey
+                # split out per chromosome storage
+                allhits[virtualChromosome] = allhits.get(virtualChromosome, collections.OrderedDict())
+                allhits[virtualChromosome][matchedInObject] = allhits[virtualChromosome].get(matchedInObject,
+                                                                                             collections.OrderedDict())
+                allhits[virtualChromosome][matchedInObject][sampleName] = eventA
+                _match += 1
+
+                # when found, continue? Is it the best approach to break out of this long list comparison?
+                matched = True
+                break
+
+        if not matched:
+            # add eventA to 'allhits[virtualChromosome]'
+            if eventA.virtualChr == "chr17chr17":
+                print("{} - {}".format(eventA.hexdigest, eventA))
+            matchedInObject = eventA.hexdigest
+            eventA.matched_in = matchedInObject
+            virtualChromosome = eventA.virtualChr
+            allhits[virtualChromosome] = allhits.get(virtualChromosome, collections.OrderedDict())
+            allhits[virtualChromosome][matchedInObject] = allhits[virtualChromosome].get(matchedInObject,
+                                                                                         collections.OrderedDict())
+            allhits[virtualChromosome][matchedInObject][sampleName] = eventA
+            # logger.debug("Adding new match in {} now {} items".format(
+            #     virtualChromosome,
+            #     len(allhits[virtualChromosome])
+            # ))
+
+    logger.debug("Common hits so far in {}: {} / {} vs {}".format(virtualChromosome, _match, s1_n_calls, s2_n_calls))
+    return allhits[virtualChromosome]
+
+
+def report(sampleNames, vcf_output_file, tsv_output_file, bed_output_file, commonhits, regions_out_file):
+    # vcf header
+    print(vcfHeader(sampleList=sampleNames), file=vcf_output_file)
+
+    # tsv file
+    samplecols = "\t".join(map(lambda x: "{}\tsize".format(x), sampleNames))
+    header_line = "\t".join(['ChrA', 'ChrApos', 'ChrB', 'ChrBpos', 'SVTYPE', 'DP', 'Size', samplecols])
+
+    tsv_report_output = open(tsv_output_file, 'w')
+    tsv_report_output.write("{}\n".format(header_line))
+
+    bed_structural_events = open(bed_output_file, 'w')
+    all_locations = []
+
+    for virtualChr in natsorted(commonhits.keys()):
+        for matchKey, hits in sorted(commonhits[virtualChr].items(), key=lambda hit: hit[1].items()[0][1].chrApos):
+            if len(hits):
+                # check which samples have the same event
+                locations_found = []
+                for sample in sampleNames:
+                    if sample in hits.keys():
+                        locations_found.append("{}\t{}".format(hits[sample], hits[sample].size))
+                        # track all locations found for later intersecting or complementing the set of found/not-found
+                        all_locations.append(hits[sample])
+                    else:
+                        locations_found.append("\t")
+                # get the key with the highest DP
+                sorted_by_dp = sorted(hits.items(), key=lambda hit: hit[1].dp, reverse=True)
+                fKey = sorted_by_dp[0][0]
+                outputRecord = hits[fKey]
+
+                logger.debug(outputRecord)
+                logger.debug(hits)
+
+                print(formatVCFRecord(outputRecord, hits, sampleNames), file=vcf_output_file)
+
+                tsv_report_output.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format(
+                    outputRecord.chrA,
+                    outputRecord.chrApos,
+                    outputRecord.chrB,
+                    outputRecord.chrBpos,
+                    outputRecord.sv_type,
+                    outputRecord.dp,
+                    outputRecord.size,
+                    "\t".join(locations_found)))
+                bed_structural_events.write(formatBedTrack(outputRecord))
+
+    for loc in all_locations:
+        print(loc.bedRow, file=regions_out_file)
+
+    tsv_report_output.close()
+    regions_out_file.close()
+
+
+def startMerge(vcf_files, exclusion_regions, output_file, centerpointFlanking, bedoutput, transonly=False,
+               regions_out="regions_out.bed", vcf_output="output.vcf"):
+    regions_out_file = open(regions_out, "w")
+    vcf_output_file = open(vcf_output, "w")
+
+    # sampleList stores the paths to the input VCF files
+    sampleList = vcf_files
+    # sampleNames stores the real sample names (currently derived from the filenames only);
+    # only single-sample VCFs are supported for now
+    # TODO: extract the sample names from the sample columns in the VCF
+    sampleNames = [
+        os.path.basename(x).replace(".vcf", "").replace(".realign", "").replace(".baserecal", "").replace(".dedup", "")
+        for x in vcf_files]
+
+    # sampleDB stores the file read handles to the VCF files
+    sampleDB = collections.OrderedDict()
+    # svDB stores the parsed contents of the SV calls
+    svDB = collections.OrderedDict()
+
+    allhits = collections.OrderedDict()
+    chromosomes_to_check = []
+
+    edb = []
+    if not isinstance(exclusion_regions, list):
+        exclusion_regions = []
+    for exclusion_region in exclusion_regions:
+        edb += build_exclusion(exclusion_region)
+
+    for i, samplePath in enumerate(sampleList):
+        sampleName = sampleNames[i]
+        logger.info('Reading SV-events from sample: {} '.format(samplePath))
+        sampleDB[sampleName] = vcf.Reader(open(samplePath, 'r'))
+
+        # extract SV caller from header
+        sv_caller = sampleDB[sampleName].metadata.get('source', [os.path.basename(samplePath).replace('.vcf', '')]).pop(0).split(' ').pop(0)
+
+        svDB[sampleName] = loadEventFromVCF(sampleName, sampleDB[sampleName], edb, centerpointFlanking, transonly,
+                                            sv_caller)
+        n_events = sum([len(calls) for chromlist, calls in svDB[sampleName].items()])
+        logger.info('Loaded SV-events from sample: {} '.format(n_events))
+
+        # these are virtual chromosomes already
+        chromosomes_to_check += svDB[sampleName].keys()
+        chromosomes_to_check = list(set(chromosomes_to_check))
+        chromosomes_to_check.sort()
+
+        for _chromosome in chromosomes_to_check:
+            s2_calls_in_chromosome = svDB[sampleName].get(_chromosome, [])
+            allhits[_chromosome] = matchEvents(sampleName, s2_calls_in_chromosome, allhits, _chromosome)
+
+        del svDB[sampleName]
+
+    report(sampleNames, vcf_output_file, output_file, bedoutput, allhits, regions_out_file)
 
 def main():
     parser = argparse.ArgumentParser()
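
The incremental strategy above folds one sample at a time into allhits instead of comparing every pair of samples, which is why svDB[sampleName] can be dropped right after matching. The nesting that matchEvents builds up looks roughly like this — keys and values are hypothetical placeholders:

    # virtual chromosome -> match key (sha1 hexdigest of the first event seen
    # for this hit) -> sample name -> Event
    allhits = {
        "chr1chr1": {
            "9f2c4e...": {
                "sampleA": "<Event DEL chr1:10000-14000>",
                "sampleB": "<Event DEL chr1:10020-13990>",
            },
        },
    }
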
diff --git a/pysvtools/models/event.py b/pysvtools/models/event.py
index 025ada6..008486c 100644
--- a/pysvtools/models/event.py
+++ b/pysvtools/models/event.py
@@ -3,14 +3,14 @@
 
 from __future__ import print_function
 
 import hashlib
 
-__desc__ = """"""
+__desc__ = """Describes a Structural Variation event, based on the VCF v4.2 specification"""
 __author__ = "Wai Yi Leung"
 
 
 class Event(object):
     centerpoint_flanking = 100
 
-    def __init__(self, chrA, chrApos, chrB, chrBpos, sv_type=None, cp_flank=None, dp=0, svmethod=""):
+    def __init__(self, chrA, chrApos, chrB, chrBpos, sv_type=None, cp_flank=None, dp=0, gt="", svmethod="", *args, **kwargs):
         (self.chrA, self.chrApos), (self.chrB, self.chrBpos) = sorted([(chrA, chrApos), (chrB, chrBpos)])
 
         self.chrApos = int(self.chrApos)
@@ -28,18 +28,37 @@ def __init__(self, chrA, chrApos, chrB, chrBpos, sv_type=None, cp_flank=None, dp
         _vChr.sort()
         self.virtualChr = "".join(_vChr)
 
-        self.centerpoint = self.get_centerpoint
         self.sv_type = sv_type
         self.centerpointFlanking = cp_flank or self.centerpoint_flanking
+
+        self.gt = gt
         self.dp = dp
-        self._hash = None
+        self._hash = hashlib.sha1(str(self).encode('utf-8')).hexdigest()
+
+    @property
+    def svmethod(self):
+        if self._svmethod.startswith("clever"):
+            return "clever"
+        elif self._svmethod.startswith("breakdancer"):
+            return "breakdancer"
+        return self._svmethod
+
+    @svmethod.setter
+    def svmethod(self, value):
+        self._svmethod = value
+
+    @svmethod.deleter
+    def svmethod(self):
+        del self._svmethod
 
     @property
-    def get_centerpoint(self):
-        cnt = int(self.size / 2)
+    def centerpoint(self):
+        center = int(self.size / 2)
+        # order the positions ascending
         positions = [self.chrApos, self.chrBpos]
         positions.sort()
-        centerpoint = positions[0] + cnt
+
+        centerpoint = positions[0] + center
         return centerpoint
 
     @property
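
A short worked example for the reworked centerpoint property; the construction is hypothetical and assumes Event.size (not shown in this diff) is the distance between the two sorted breakpoints:

    e = Event("chr1", 1000, "chr1", 5000, sv_type="DEL", cp_flank=100, dp=20, gt="0/1")
    # positions sort to [1000, 5000]; if e.size == 4000, then
    # e.centerpoint == 1000 + int(4000 / 2) == 3000
    print(e.centerpoint)
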
diff --git a/pysvtools/utils.py b/pysvtools/utils.py
index 8fbe9bb..cdf00c4 100644
--- a/pysvtools/utils.py
+++ b/pysvtools/utils.py
@@ -32,7 +32,7 @@ def extractTXmate(record):
     return [chrB, chrBpos]
 
 
-def vcfHeader():
+def vcfHeader(sampleList=["default"]):
     ts_now = datetime.datetime.now()
     vcf_header = """##fileformat=VCFv4.1
 ##fileDate={filedate}
@@ -72,7 +72,7 @@
 ##INFO=
 ##INFO=""".format(filedate=ts_now.strftime("%Y%m%d"), version=__version__)
 
-    return vcf_header + "\n" + "#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT default"
+    return vcf_header + "\n" + "#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT {samples}".format(samples="\t".join(sampleList))
 
 
 def extractTXmateINFOFIELD(breakpoints):
@@ -101,6 +101,16 @@
     return 0
 
 
+def extractGTFromRecord(record):
+    if 'GT' in record.INFO.keys():
+        if isinstance(record.INFO['GT'], list):
+            return record.INFO['GT'][0]
+        return record.INFO['GT']
+    elif len(record.samples):
+        return getattr(record.samples[0].data, 'GT', 0)
+    return "1/."
+
+
 def firstFromList(arr):
     if type(arr) == type([]):
         return arr[0]
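
extractGTFromRecord tries three sources in order: a GT key that some SV callers write into INFO, then the first sample's FORMAT/GT via PyVCF, and finally the "1/." placeholder that the old code hard-coded. A hedged usage sketch — the input file name is hypothetical:

    import vcf  # PyVCF

    reader = vcf.Reader(open("sample.vcf", "r"))
    record = next(reader)
    print(extractGTFromRecord(record))  # e.g. "0/1", or "1/." when no GT is present
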
@@ -169,7 +179,7 @@
     return formatted_bed
 
 
-def formatVCFRecord(mergedHit):
+def formatVCFRecord(mergedHit, hits, sampleNames):
     # TODO: write the DP for each of the callers/sample
     INFOFIELDS = "IMPRECISE;SVTYPE={};CHR2={};END={};SVMETHOD={svmethod}".format(
         mergedHit.sv_type,
@@ -177,11 +187,22 @@
         mergedHit.chrBpos,
         svmethod=mergedHit.svmethod
     )
-    FORMATFIELDS = ":".join(map(str, [
-        '1/.',
-        mergedHit.dp]))
 
-    formattedVCFRecord = "{chrA}\t{pos}\t{id}\t{ref}\t<{alt}>\t{qual}\t{filter}\t{info}\tGT:DP\t{format}".format(
+    FORMATFIELDS_SAMPLES = []
+
+    for sample in sampleNames:
+        if sample in hits.keys():
+            FORMATFIELDS_SAMPLES.append("{gt}:{dp}:{start}:{end}:{svmethod}".format(
+                gt=hits[sample].gt,
+                dp=hits[sample].dp,
+                start=hits[sample].chrApos,
+                end=hits[sample].chrBpos,
+                svmethod=hits[sample].svmethod
+            ))
+        else:
+            FORMATFIELDS_SAMPLES.append("./.:0:0:0:None")
+
+    formattedVCFRecord = "{chrA}\t{pos}\t{id}\t{ref}\t<{alt}>\t{qual}\t{filter}\t{info}\t{formattypes}\t{format}".format(
         chrA=mergedHit.chrA,
         pos=mergedHit.chrApos,
         id='.',
@@ -190,7 +211,8 @@
         qual='100',
         filter='PASS',
         info=INFOFIELDS,
-        format=FORMATFIELDS
+        formattypes="GT:DP:START:END:SVMETHOD",
+        format="\t".join(FORMATFIELDS_SAMPLES)
     )
     return formattedVCFRecord
diff --git a/requirements.txt b/requirements.txt
index 5c35061..434222d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
 natsort
-PyVCF==0.6.7
\ No newline at end of file
+PyVCF==0.6.8
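
For reference, a merged VCF line as the new formatVCFRecord would emit it for two inputs named sampleA and sampleB — values are hypothetical, tabs are shown as aligned spaces, and the REF/ALT values come from context lines not visible in this diff:

    #CHROM  POS    ID  REF  ALT    QUAL  FILTER  INFO                                                      FORMAT                    sampleA                    sampleB
    chr1    10000  .   N    <DEL>  100   PASS    IMPRECISE;SVTYPE=DEL;CHR2=chr1;END=14000;SVMETHOD=clever  GT:DP:START:END:SVMETHOD  0/1:20:10000:14000:clever  ./.:0:0:0:None
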