diff --git a/lookups.py b/lookups.py index 1de292b..712423b 100644 --- a/lookups.py +++ b/lookups.py @@ -364,95 +364,119 @@ def get_summary_for_intervalset(db, intervalset): ] -def get_variants_subset_for_intervalset(db, intervalset, columns_to_return, order, filter_info, skip, length): - # 1. match what the user asked for - using [intervalset, filter_info] - # 2. project to just keys for sorting, sort, and get `_id`s - using [order] - # 3. get `n_filtered` and `length`-many `_id`s - using [skip, length] - # 4. look up those `_id`s and project - using [columns_to_return] - st = time.time() - - mongo_match = [intervalset.to_mongo()] - if filter_info.get('filter_value',None) is not None: - if filter_info['filter_value'] == 'PASS': mongo_match.append({'filter': 'PASS'}) - elif filter_info['filter_value'] == 'not PASS': mongo_match.append({'filter': {'$ne': 'PASS'}}) - if isinstance(filter_info.get('maf_ge',None),(float,int)): - assert 0 <= filter_info['maf_ge'] <= 0.5 - if filter_info['maf_ge'] > 0: mongo_match.append({'$and': [{'allele_freq': {'$gte': filter_info['maf_ge']}},{'allele_freq': {'$lte': 1-filter_info['maf_ge']}}]}) - if isinstance(filter_info.get('maf_le',None),(float,int)): - assert 0 <= filter_info['maf_le'] <= 0.5 - if filter_info['maf_le'] < 0.5: mongo_match.append({'$or': [{'allele_freq': {'$lte': filter_info['maf_le']}},{'allele_freq': {'$gte': 1-filter_info['maf_le']}}]}) - if filter_info.get('category',None) is not None: - if filter_info['category'].strip() == 'LoF': mongo_match.append({'worst_csqidx': {'$lt': Consequence.as_obj['n_lof']}}) - elif filter_info['category'].strip() == 'LoF+Missense': mongo_match.append({'worst_csqidx': {'$lt': Consequence.as_obj['n_lof_mis']}}) - - cols = { - # after pre-processing, these will look like: - # : {'sort': {'project': , 'sort_key': }, 'return': {'project': }} - # : {'sort': False, 'return': {'project': }} - 'allele': {'return': ['rsids', 'ref', 'alt']}, - 'pos': {'sort': 'xpos'}, - 'csq': {'sort': 'worst_csqidx', 'return':{'project': { - 'worst_csqidx':1, - 'HGVS':'$worst_csq_HGVS', - 'low_conf': {'$in':[{'k':"LoF",'v':"LC"},{'$objectToArray':{'$arrayElemAt':["$vep_annotations",0]}}]}, # Gross, but I don't know a better way. - }}}, - 'filter': {}, - 'allele_count': {'sort': True}, - 'allele_num': {'sort': True}, +class VariantSubsetFetcher(object): + + # Define the columns that we have available. + # After the substitutions in the `for` loop below, each entry is like: + # { : { 'sort': , 'return': {'project': }}} + # where is either: + # - `False` (meaning that we cannot sort by this column), or + # - { 'project': , 'sort_key': } + allowed_columns = { + 'allele': {'sort': False, + 'return': {'project': {'rsids':True, 'ref':True, 'alt':True}}}, + 'pos': {'sort': {'project':{'xpos':True}, 'sort_key':'xpos'}}, + 'csq': {'sort': {'project':{'worst_csqidx':1}, 'sort_key':'worst_csqidx'}, + 'return':{ 'project': { + 'worst_csqidx':1, + 'HGVS':'$worst_csq_HGVS', + # `.vep_annotations` is sorted descending by `.worst_csq`, so the first object is the most-severe consequence. + # I want to just set `low_conf = (variant.vep_annotations[0].LoF == "LC")`, but I don't know how. + 'low_conf': {'$in':[{'k':"LoF",'v':"LC"},{'$objectToArray':{'$arrayElemAt':["$vep_annotations",0]}}]}}}}, 'het': {'sort': {'project': {'het': {'$subtract':['$allele_count',{'$multiply':[2,'$hom_count']}]}}, 'sort_key': 'het'}, 'return': {'project': {'het': {'$subtract':['$allele_count',{'$multiply':[2,'$hom_count']}]}}}}, 'hom_count': {'sort': True}, + 'allele_count': {'sort': True}, + 'allele_num': {'sort': True}, 'allele_freq': {'sort': True}, 'cadd_phred': {'sort': True}, + 'filter': {'sort': False}, } - for name, col in cols.items(): - try: - if 'sort' not in col: col['sort'] = False - if col['sort'] == True: col['sort'] = name - if isinstance(col['sort'], str): col['sort'] = {'project': {col['sort']:1}, 'sort_key':col['sort']} - assert col['sort'] == False or isinstance(col['sort']['project'], dict) and isinstance(col['sort']['sort_key'], str) - if 'return' not in col: col['return'] = [name] - if isinstance(col['return'], list): col['return'] = {'project': {k:1 for k in col['return']}} - assert isinstance(col['return']['project'], dict) - except: - print('COL = ', col) - raise - - mongo_projection_before_sort = {} - mongo_sort = OrderedDict() - for order_item in order: - direction = {'asc': pymongo.ASCENDING, 'desc':pymongo.DESCENDING}[order_item['dir']] - colidx = order_item['column']; colname = columns_to_return[colidx]['name']; col = cols[colname] - mongo_projection_before_sort.update(col['sort']['project']) - mongo_sort[col['sort']['sort_key']] = direction - - mongo_projection = mkdict(*[cols[ctr['name']]['return']['project'] for ctr in columns_to_return], _id=False) - - v_ids_curs = db.variants.aggregate([ - {'$match': {'$and': mongo_match}}, - {'$project': mongo_projection_before_sort}, - {'$sort': mongo_sort}, - {'$project': {'_id': 1}}, - {'$group': {'_id':0, 'count':{'$sum':1}, 'results':{'$push':'$$ROOT'}}}, - {'$project': {'_id':0, 'count':1, 'ids':{'$slice':['$results',skip,length]}}}, - ]) - print '## VARIANT_SUBSET: spent {:.3f} seconds creating cursor'.format(time.time()-st); st = time.time() - v_ids_result = list(v_ids_curs) - if len(v_ids_result) == 0: - n_filtered, variants = 0, [] - else: - assert len(v_ids_result) == 1 - n_filtered = v_ids_result[0]['count'] - print '## VARIANT_SUBSET: spent {:0.3f} seconds counting {} variants that match filters'.format(time.time()-st, n_filtered); st = time.time() - v_ids = [v['_id'] for v in v_ids_result[0]['ids']] - variants = [next(db.variants.aggregate([{'$match': {'_id': vid}}, {'$project': mongo_projection}])) for vid in v_ids] # b/c fancy projections require .aggregate() - print '## VARIANT_SUBSET: spent {:0.3f} seconds fetching {} full variants by id'.format(time.time()-st, len(variants)); st = time.time() - - return { - 'recordsFiltered': n_filtered, - 'recordsTotal': n_filtered, - 'data': variants - } + for column_name, column_config in allowed_columns.items(): + if column_config['sort'] == True: + column_config['sort'] = {'project': {column_name:True}, 'sort_key':column_name} + if 'return' not in column_config: + column_config['return'] = {'project': {column_name: True}} + if column_config['sort'] != False: + assert isinstance(column_config['sort']['project'], dict) + assert isinstance(column_config['sort']['sort_key'], str) + assert isinstance(column_config['return']['project'], dict) + + @staticmethod + def fetch(db, intervalset, columns_to_return, order, filter_info, skip, length): + ''' + `intervalset` must be an IntervalSet of the intervals to query + `columns_to_return` is a list of columns from VariantSubsetFetcher.allowed_columns.keys() + `order` is like [(column_index, direction), ...], where `column_index` is an index of `columns_to_return` and `direction` is 'desc' or 'asc' + `filter_info` is like {'filter_value':'PASS', 'maf_le':0.25} + `skip` is the number of variants to skip + `length` is the number of variants to return (after possibly skipping some) + + This method works by: + 1. Get all variants that are in an interval from `intervalset` and match `filter_info`. + 2. Project down to only the fields needed for sorting, and keep `_id`. + 3. Sort. + 4. Remove all fields except `_id`, to save memory. + 5. Set `n_filtered` to be the number of variants remaining. + 6. For each `_id`, look up the original variant, apply the projection for each column in `columns_to_return`, and return it. + ''' + + st = time.time() # start_time + + # Build the array to use with `$match` to select the variants to return using `intervalset` and `filter_info` + mongo_match = [intervalset.to_mongo()] + if filter_info.get('filter_value',None) is not None: + if filter_info['filter_value'] == 'PASS': mongo_match.append({'filter': 'PASS'}) + elif filter_info['filter_value'] == 'not PASS': mongo_match.append({'filter': {'$ne': 'PASS'}}) + if isinstance(filter_info.get('maf_ge',None),(float,int)): + assert 0 <= filter_info['maf_ge'] <= 0.5 + if filter_info['maf_ge'] > 0: mongo_match.append({'$and': [{'allele_freq': {'$gte': filter_info['maf_ge']}},{'allele_freq': {'$lte': 1-filter_info['maf_ge']}}]}) + if isinstance(filter_info.get('maf_le',None),(float,int)): + assert 0 <= filter_info['maf_le'] <= 0.5 + if filter_info['maf_le'] < 0.5: mongo_match.append({'$or': [{'allele_freq': {'$lte': filter_info['maf_le']}},{'allele_freq': {'$gte': 1-filter_info['maf_le']}}]}) + if filter_info.get('category',None) is not None: + if filter_info['category'].strip() == 'LoF': mongo_match.append({'worst_csqidx': {'$lt': Consequence.as_obj['n_lof']}}) + elif filter_info['category'].strip() == 'LoF+Missense': mongo_match.append({'worst_csqidx': {'$lt': Consequence.as_obj['n_lof_mis']}}) + + mongo_projection_before_sort = {} + mongo_sort = OrderedDict() + for order_item in order: + direction = {'asc':pymongo.ASCENDING, 'desc':pymongo.DESCENDING}[order_item['dir']] + column_name = columns_to_return[order_item['column']]['name'] + column_config = VariantSubsetFetcher.allowed_columns[colname] + mongo_projection_before_sort.update(column_config['sort']['project']) + mongo_sort[column_config['sort']['sort_key']] = direction + + mongo_projection = {'_id': False} + for column_name in columns_to_return: + mongo_projection.update(VariantSubsetFetcher.allowed_columns[column_name]['return']['project']) + + results_cursor = db.variants.aggregate([ + {'$match': {'$and': mongo_match}}, + {'$project': mongo_projection_before_sort}, + {'$sort': mongo_sort}, + {'$project': {'_id': 1}}, + {'$group': {'_id':0, 'count':{'$sum':1}, 'results':{'$push':'$$ROOT'}}}, + {'$project': {'_id':0, 'count':1, 'ids':{'$slice':['$results',skip,length]}}}, + ]) + print '## VARIANT_SUBSET: spent {:.3f} seconds creating cursor'.format(time.time()-st); st = time.time() + results = list(results_cursor) + if len(results) == 0: + n_filtered, variants = 0, [] + else: + assert len(results) == 1 + n_filtered = results[0]['count'] + print '## VARIANT_SUBSET: spent {:0.3f} seconds counting {} variants that match filters'.format(time.time()-st, n_filtered); st = time.time() + variant_ids = [v['_id'] for v in results[0]['ids']] + # `db.variants.find(...)` cannot handle complex projections, so we use `next(db.variants.aggregate(...))` instead. + variants = [next(db.variants.aggregate([{'$match': {'_id': v_id}}, {'$project': mongo_projection}])) for v_id in variant_ids] + print '## VARIANT_SUBSET: spent {:0.3f} seconds fetching {} full variants by id'.format(time.time()-st, len(variants)); st = time.time() + + return { + 'recordsFiltered': n_filtered, + 'recordsTotal': n_filtered, + 'data': variants + } def get_variants_csv_str_for_intervalset(db, intervalset):