112 changes: 63 additions & 49 deletions pysvtools/merge.py
@@ -22,11 +22,11 @@
     print("No PyVCF installation was found, please install with:\n\tpip install pyvcf")
     sys.exit(1)
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
 from pysvtools.models import Event
-from pysvtools.utils import extractTXmate, extractDPFromRecord, getSVType, getSVLEN, formatBedTrack, \
+from pysvtools.utils import extractTXmate, extractDPFromRecord, extractGTFromRecord, getSVType, getSVLEN, formatBedTrack, \
     formatVCFRecord, vcfHeader, build_exclusion
 
 
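Note: the newly imported extractGTFromRecord is defined in pysvtools/utils.py, which this diff does not show; the log level also moves from INFO to DEBUG, enabling the logger.debug output added near the end of this diff. A minimal sketch of what such a helper could look like with PyVCF, assuming single-sample VCFs (the same assumption the sample naming below makes):

    # Hypothetical sketch only; the real helper lives in pysvtools/utils.py.
    def extractGTFromRecord(record, default="./."):
        """Return the GT string of the first sample call, e.g. '0/1'."""
        try:
            return record.samples[0].data.GT or default
        except (IndexError, AttributeError):
            return default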
@@ -50,7 +50,7 @@ def __init__(self):
 
 
 # read all samples in memory
-def loadEventFromVCF(s, vcf_reader, edb, centerpointFlanking, transonly, svmethod=""):
+def loadEventFromVCF(sampleName, vcf_reader, edb, centerpointFlanking, transonly, svmethod=""):
     """
     Load VCF records and transform them to `Event` objects
     """
@@ -82,6 +82,7 @@ def loadEventFromVCF(s, vcf_reader, edb, centerpointFlanking, transonly, svmethod=""):
                   end, sv_type="TRA",
                   cp_flank=centerpointFlanking,
                   dp=extractDPFromRecord(record),
+                  gt=extractGTFromRecord(record),
                   svmethod=svmethod)
         svDB[t.virtualChr] = svDB.get(t.virtualChr, [])
         svDB[t.virtualChr].append(t)
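Each Event constructed in loadEventFromVCF now also receives the genotype via gt=. The Event model (pysvtools/models.py) is not part of this diff; presumably it gains a matching keyword argument, roughly along these lines (attribute names beyond gt are assumptions based on their usage in this file):

    class Event(object):
        # Sketch only: the real model also implements centerpoint-flank
        # equality, a hexdigest hash, virtualChr, bedRow, etc.
        def __init__(self, chrA, chrApos, chrB, chrBpos, sv_type,
                     cp_flank=0, dp=0, gt="./.", svmethod=""):
            self.chrA, self.chrApos = chrA, chrApos
            self.chrB, self.chrBpos = chrB, chrBpos
            self.sv_type = sv_type
            self.cp_flank = cp_flank
            self.dp = dp
            self.gt = gt  # new: the genotype travels with the event
            self.svmethod = svmethod
            self.matched_in = None

The same one-line gt= addition recurs in the three hunks below, covering the remaining Event construction sites.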
@@ -94,6 +95,7 @@ def loadEventFromVCF(s, vcf_reader, edb, centerpointFlanking, transonly, svmethod=""):
                   sv_type='TRA',
                   cp_flank=centerpointFlanking,
                   dp=extractDPFromRecord(record),
+                  gt=extractGTFromRecord(record),
                   svmethod=svmethod)
         try:
             svDB[t.virtualChr] = svDB.get(t.virtualChr, [])
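This second Event registers the translocation under the mate's coordinates as well, so the call can be found from either breakpoint; chrB and end come from extractTXmate. That helper is also outside this diff; a hypothetical sketch, assuming VCF breakend (BND) ALT notation such as N[chr17:198983[:

    import re

    # Hypothetical sketch; the real extractTXmate is in pysvtools/utils.py.
    def extractTXmate(record):
        """Parse mate chromosome and position from a breakend ALT string."""
        m = re.search(r'[\[\]](?P<chrom>[^:\[\]]+):(?P<pos>\d+)[\[\]]',
                      str(record.ALT[0]))
        if m:
            return m.group('chrom'), int(m.group('pos'))
        return None, None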
@@ -120,6 +122,7 @@ def loadEventFromVCF(s, vcf_reader, edb, centerpointFlanking, transonly, svmethod=""):
                       sv_type=SVTYPE,
                       cp_flank=centerpointFlanking,
                       dp=extractDPFromRecord(record),
+                      gt=extractGTFromRecord(record),
                       svmethod=svmethod)
         except:
             print("Unexpected error:", sys.exc_info()[0])
@@ -146,6 +149,7 @@ def loadEventFromVCF(s, vcf_reader, edb, centerpointFlanking, transonly, svmethod=""):
                       sv_type=SVTYPE,
                       cp_flank=centerpointFlanking,
                       dp=extractDPFromRecord(record),
+                      gt=extractGTFromRecord(record),
                       svmethod=svmethod)
         except:
             print("Unexpected error:", sys.exc_info()[0])
@@ -162,9 +166,15 @@ def startMerge(vcf_files, exclusion_regions, output_file, centerpointFlanking, b
     regions_out_file = open(regions_out, "w")
     vcf_output_file = open(vcf_output, "w")
 
-    samplelist = vcf_files
+    # sampleList stores the paths to the input VCF files
+    sampleList = vcf_files
+    # sampleNames stores the real sample names (currently derived from the file names only); only one sample per VCF is supported for now
+    # TODO: extract the sample name from the sample columns in the VCF
+    sampleNames = [ os.path.basename(x).replace(".vcf", "").replace(".realign", "").replace(".baserecal", "").replace(".dedup", "") for x in vcf_files ]
 
+    # sampleDB stores the open readers for the VCF files
     sampleDB = collections.OrderedDict()
+    # svDB stores the parsed contents of the SV calls
     svDB = collections.OrderedDict()
 
     commonhits = collections.OrderedDict()
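The TODO could later be resolved by reading the sample name from the VCF's own header instead of the file name. A sketch of a hypothetical helper under that assumption (single-sample VCFs, PyVCF reader):

    import os
    import vcf

    def deriveSampleName(path):
        """Prefer the sample column from the #CHROM header line; fall back to the file name."""
        reader = vcf.Reader(open(path, 'r'))
        if reader.samples:
            return reader.samples[0]
        name = os.path.basename(path)
        for suffix in (".vcf", ".realign", ".baserecal", ".dedup"):
            name = name.replace(suffix, "")
        return name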
@@ -174,15 +184,16 @@ def startMerge(vcf_files, exclusion_regions, output_file, centerpointFlanking, b
     for exclusion_region in exclusion_regions:
         edb += build_exclusion(exclusion_region)
 
-    for s in samplelist:
-        logger.info('Reading SV-events from sample: {} '.format(s))
-        sampleDB[s] = vcf.Reader(open(s, 'r'))
+    for i, samplePath in enumerate(sampleList):
+        sampleName = sampleNames[i]
+        logger.info('Reading SV-events from sample: {} '.format(samplePath))
+        sampleDB[sampleName] = vcf.Reader(open(samplePath, 'r'))
 
         # extract SV caller from header
-        sv_caller = sampleDB[s].metadata.get('source', [os.path.basename(s).strip('.vcf')]).pop(0).split(' ').pop(0)
+        sv_caller = sampleDB[sampleName].metadata.get('source', [os.path.basename(samplePath).strip('.vcf')]).pop(0).split(' ').pop(0)
 
-        svDB[s] = loadEventFromVCF(s, sampleDB[s], edb, centerpointFlanking, transonly, sv_caller)
-        n_events = sum([len(calls) for chromlist, calls in svDB[s].items()])
+        svDB[sampleName] = loadEventFromVCF(sampleName, sampleDB[sampleName], edb, centerpointFlanking, transonly, sv_caller)
+        n_events = sum([len(calls) for chromlist, calls in svDB[sampleName].items()])
         logger.info('Loaded SV-events from sample: {} '.format(n_events))
 
     pairs_to_check = itertools.combinations(svDB.keys(), 2)
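The sv_caller line takes the first token of the first ##source= header entry, falling back to the file name. Unpacked, the equivalent logic reads:

    source_lines = sampleDB[sampleName].metadata.get(
        'source', [os.path.basename(samplePath).strip('.vcf')])
    sv_caller = source_lines[0].split(' ')[0]  # first token, e.g. the caller name

One caveat carried over from the old code: str.strip('.vcf') strips any of the characters '.', 'v', 'c', 'f' from both ends rather than removing a suffix, so a file named cf_sample.vcf would come out as '_sample'.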
@@ -195,9 +206,9 @@ def startMerge(vcf_files, exclusion_regions, output_file, centerpointFlanking, b
     chromosomes_to_check.sort()
 
     for (s1, s2) in pairs_to_check:
-        _s1 = os.path.basename(s1)
-        _s2 = os.path.basename(s2)
-        logger.debug('Pairwise compare: {} x {}'.format(_s1, _s2))
+        # _s1 = os.path.basename(s1)
+        # _s2 = os.path.basename(s2)
+        logger.debug('Pairwise compare: {} x {}'.format(s1, s2))
         for _chromosome in chromosomes_to_check:
             s1_calls_in_chromosome = svDB[s1].get(_chromosome, [])
             s2_calls_in_chromosome = svDB[s2].get(_chromosome, [])
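pairs_to_check enumerates every unordered pair of samples exactly once. Since svDB is now keyed by sample name rather than file path, the two basename lines above became redundant and the debug message already prints readable names. For illustration:

    >>> import itertools
    >>> list(itertools.combinations(['tumor', 'normal', 'relapse'], 2))
    [('tumor', 'normal'), ('tumor', 'relapse'), ('normal', 'relapse')]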
@@ -206,46 +217,46 @@ def startMerge(vcf_files, exclusion_regions, output_file, centerpointFlanking, b
 
             _match = 0
 
-            for i, t1 in enumerate(s1_calls_in_chromosome):
-                for j, t2 in enumerate(s2_calls_in_chromosome):
-                    if not (t1 and t2):
+            for i, eventA in enumerate(s1_calls_in_chromosome):
+                for j, eventB in enumerate(s2_calls_in_chromosome):
+                    if not (eventA and eventB):
                         continue
-                    if t1 == t2:
+                    if eventA == eventB:
                         # determine the object with the most DP and size
-                        if t1.size >= t2.size:
-                            _m = t1
+                        if eventA.size >= eventB.size:
+                            referenceEvent = eventA
                         else:
-                            _m = t2
+                            referenceEvent = eventB
 
                         # get the hashes from all hits; check whether one of them was already evaluated and is thus in the table
 
-                        if t1.matched_in or t2.matched_in:
-                            m = t1.matched_in or t2.matched_in
+                        if eventA.matched_in or eventB.matched_in:
+                            matchedInObject = eventA.matched_in or eventB.matched_in
                         else:
-                            m = _m.hexdigest
+                            matchedInObject = referenceEvent.hexdigest
 
-                        t1.matched_in = m
-                        t2.matched_in = m
+                        eventA.matched_in = matchedInObject
+                        eventB.matched_in = matchedInObject
 
                         # TODO: write a getter method on Event instead of accessing the internal class variable
-                        virtualchrom = _m.virtualChr
+                        virtualChromosome = referenceEvent.virtualChr
 
                         # split out per-chromosome storage
-                        commonhits[virtualchrom] = commonhits.get(virtualchrom, collections.OrderedDict())
-                        commonhits[virtualchrom][m] = commonhits[virtualchrom].get(m, collections.OrderedDict())
-                        commonhits[virtualchrom][m][s1] = t1
-                        commonhits[virtualchrom][m][s2] = t2
+                        commonhits[virtualChromosome] = commonhits.get(virtualChromosome, collections.OrderedDict())
+                        commonhits[virtualChromosome][matchedInObject] = commonhits[virtualChromosome].get(matchedInObject, collections.OrderedDict())
+                        commonhits[virtualChromosome][matchedInObject][s1] = eventA
+                        commonhits[virtualChromosome][matchedInObject][s2] = eventB
                         _match += 1
 
                         # when a match is found, is breaking off this long list comparison really the best approach?
                         break
             logger.debug("Common hits so far in {}: {} / {} vs {}".format(_chromosome, _match, s1_n_calls, s2_n_calls))
 
     # vcf header
-    print(vcfHeader(), file=vcf_output_file)
+    print(vcfHeader(sampleList=sampleNames), file=vcf_output_file)
 
     # tsv file
-    samplecols = "\t".join(map(lambda x: "{}\tsize".format(os.path.basename(x).strip(".vcf")), samplelist))
+    samplecols = "\t".join(map(lambda x: "{}\tsize".format(os.path.basename(x).strip(".vcf")), sampleList))
     header_line = "\t".join(['ChrA', 'ChrApos', 'ChrB', 'ChrBpos', 'SVTYPE', 'DP', 'Size', samplecols])
 
     tsv_report_output = open(output_file, 'w')
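The merge hinges on eventA == eventB: Event equality is centerpoint-based, so two calls match when their breakpoints lie within the configured flank of each other. The model is not shown in this diff; a sketch of roughly what the comparison amounts to, with the attribute names assumed from their usage here:

    def eventsMatch(a, b):
        # Hypothetical stand-in for Event.__eq__.
        return (a.virtualChr == b.virtualChr
                and a.sv_type == b.sv_type
                and abs(a.chrApos - b.chrApos) <= a.cp_flank
                and abs(a.chrBpos - b.chrBpos) <= a.cp_flank)

Matched events then share one matched_in key (the hexdigest of the larger of the two calls), which is what groups them into a single commonhits entry per virtual chromosome.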
@@ -255,34 +266,37 @@ def startMerge(vcf_files, exclusion_regions, output_file, centerpointFlanking, b
     all_locations = []
 
     for virtualChr in natsorted(commonhits.keys()):
-        for s, items in sorted(commonhits[virtualChr].items(), key=lambda hit: hit[1].items()[0][1].chrApos):
-            if len(items):
+        for sampleName, hits in sorted(commonhits[virtualChr].items(), key=lambda hit: hit[1].items()[0][1].chrApos):
+            if len(hits):
                 # check which samples have the same hit
                 locations_found = []
-                for sample in samplelist:
-                    if sample in items.keys():
-                        locations_found.append("{}\t{}".format(items[sample], items[sample].size))
+                for sample in sampleNames:
+                    if sample in hits.keys():
+                        locations_found.append("{}\t{}".format(hits[sample], hits[sample].size))
                         # track all locations found for later intersecting or complementing the set of found/not-found
-                        all_locations.append(items[sample])
+                        all_locations.append(hits[sample])
                     else:
                         locations_found.append("\t")
                 # get the key with the highest DP
-                sorted_by_dp = sorted(items.items(), key=lambda hit: hit[1].dp, reverse=True)
+                sorted_by_dp = sorted(hits.items(), key=lambda hit: hit[1].dp, reverse=True)
                 fKey = sorted_by_dp[0][0]
-                t = items[fKey]
+                outputRecord = hits[fKey]
 
-                print(formatVCFRecord(t), file=vcf_output_file)
+                logger.debug(outputRecord)
+                logger.debug(hits)
+
+                print(formatVCFRecord(outputRecord, hits, sampleNames), file=vcf_output_file)
 
                 tsv_report_output.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format(
-                    t.chrA,
-                    t.chrApos,
-                    t.chrB,
-                    t.chrBpos,
-                    t.sv_type,
-                    t.dp,
-                    t.size,
+                    outputRecord.chrA,
+                    outputRecord.chrApos,
+                    outputRecord.chrB,
+                    outputRecord.chrBpos,
+                    outputRecord.sv_type,
+                    outputRecord.dp,
+                    outputRecord.size,
                     "\t".join(locations_found)))
-                bed_structural_events.write(formatBedTrack(t))
+                bed_structural_events.write(formatBedTrack(outputRecord))
 
     for loc in all_locations:
         print(loc.bedRow, file=regions_out_file)
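One caveat in the sort key above, unchanged by this PR: hit[1].items()[0][1] indexes into dict items, which only works on Python 2 where items() returns a list. A Python 3-safe spelling of the same key, ordering merged hits by the start position of their first event:

    sorted_hits = sorted(
        commonhits[virtualChr].items(),
        key=lambda hit: next(iter(hit[1].values())).chrApos)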