diff --git a/.gitignore b/.gitignore index 7b7b098..e60d1e3 100644 --- a/.gitignore +++ b/.gitignore @@ -35,4 +35,9 @@ data/D105heartW632_361712_originally_labeled_D120/ Users data/D105skinW6_32_305183/ hla_data.js -hla_data.json \ No newline at end of file +hla_data.json + +#ignore extracted data +data/ref/* + +uMAP.py \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 6f3a291..f673a71 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,3 @@ { - "liveServer.settings.port": 5501 + "liveServer.settings.port": 5502 } \ No newline at end of file diff --git a/README.md b/README.md index 71b0429..c43f925 100644 --- a/README.md +++ b/README.md @@ -40,11 +40,6 @@ cd HLA-PepClust/ .\hlapepclust-env\Scripts\activate ``` -3. **Upgrade `pip`** - ```bash - pip install --upgrade pip - ``` - ## Installing Dependencies 1. **Navigate to the project directory** (if not already in it) @@ -87,32 +82,36 @@ clust-search \ ```bash clust-search data/D90_HLA_3844874 data/ref_data/Gibbs_motifs_human/output_matrices_human \ - --hla_types A0201,A0101,B1302,B3503,C0401 \ - --n_clusters 6 \ - --species human \ - --output test_results \ - --processes 4 \ - --threshold 0.6 + --hla_types A0201,A0101,B1302,B3503,C0401 \ # Specify list of HLA alleles to search + --n_clusters 6 \ # Restrict analysis to 6 Gibbs clusters + --species human \ # Species to evaluate [human, mouse] + --output My_results_directory \ # Directory where results will be saved + --processes 4 \ # Number of parallel processes to use + --threshold 0.7 \ # Correlation threshold for motif matching + --topNHits 3 # Report top-N HLA matches for each Gibbs motif ``` ## Command Line Arguments +| Argument | Type | Description | Default | +|------------------|---------|-----------------------------------------------------------------------------|--------------------------| +| `gibbs_folder` | `str` | Path to test folder containing matrices. | *Required* | +| `reference_folder` | `str` | Path to reference folder containing matrices. | *Required* | +| `-o, --output` | `str` | Path to output folder. | `"output"` | +| `-hla, --hla_types` | `list` | List of HLA types to search. | *All* | +| `-p, --processes` | `int` | Number of parallel processes to use. | `4` | +| `-n, --n_clusters` | `str` | Number of clusters to search for. | `"all"` | +| `-t, --threshold` | `float` | Motif similarity threshold. | `0.70` | +| `-s, --species` | `str` | Species to search [Human, Mouse]. | `"human"` | +| `-db, --database` | `str` | Generate a motif database from a configuration file. | `"data/config.json"` | +| `-st, --Searchtype` | `str` | Type of search to perform [Numba, IO]. | `"Numba"` | +| `-k, --best_KL` | `bool` | Find the best KL divergence only. | `False` | +| `--topNHits` | `int` | Number of top hits to retain per Gibbs matrix. | `3` | +| `-l, --log` | `bool` | Enable logging. | `False` | +| `-im, --immunolyser` | `bool` | Enable immunolyser output. | `False` | +| `-npDB, --NumbaDB` | `str` | Path to the Array database folder. | `"data/ref_data/human_db"` | +| `-c, --credits` | `bool` | Show credits for the motif database pipeline. | `False` | +| `-v, --version` | `bool` | Show the version of the pipeline. | `False` | -| Argument | Type | Description | Default | -|----------|------|-------------|---------| -| `gibbs_folder` | `str` | Path to test folder containing matrices. | *Required* | -| `reference_folder` | `str` | Path to reference folder containing matrices. 
| *Required* | -| `-o, --output` | `str` | Path to output folder. | `"output"` | -| `-hla, --hla_types` | `list` | List of HLA types to search. | *All* | -| `-p, --processes` | `int` | Number of parallel processes to use. | `4` | -| `-n, --n_clusters` | `int` | Number of clusters to search for. | `"all"` | -| `-t, --threshold` | `float` | Motif similarity threshold. | `0.5` | -| `-s, --species` | `str` | Species to search [Human, Mouse]. | `"human"` | -| `-db, --database` | `str` | Generate a motif database from a configuration file. | `"data/config.json"` | -| `-k, --best_KL` | `bool` | Find the best KL divergence only. | `False` | -| `-l, --log` | `bool` | Enable logging. | `False` | -| `-im, --immunolyser` | `bool` | Enable immunolyser output. | `False` | -| `-c, --credits` | `bool` | Show credits for the motif database pipeline. | `False` | -| `-v, --version` | `bool` | Show the version of the pipeline. | `False` | ## Example Output @@ -144,7 +143,15 @@ Example of `input` folder path: ![Example Output](assets/img/google-colab.png) +### Citation + +If you use **MHC-TP** in your research, please cite: +**Immunolyser 2.0: an advanced computational pipeline for comprehensive analysis of immunopeptidomic data** +Prithvi Raj Munday¹,†, Sanjay S.G. Krishna¹,†, Joshua Fehring¹, Nathan P. Croft¹, Anthony W. Purcell¹, Chen Li¹,², and Asolina Braun¹ +¹Department of Biochemistry and Molecular Biology and Biomedicine Discovery Institute, Monash University, Clayton, VIC, 3800, Australia +²Department of Medicine, School of Clinical Sciences at Monash Health, Monash University, Clayton, VIC 3168, Australia -More detailed instructions coming soon... 🚀 +*Computational and Structural Biotechnology Journal* +More detailed instructions coming soon... diff --git a/cli/HLAfreq.py b/cli/HLAfreq.py new file mode 100644 index 0000000..1eb4ee0 --- /dev/null +++ b/cli/HLAfreq.py @@ -0,0 +1,1009 @@ +""" +Download and combine HLA allele frequencies from multiple datasets. + +Download allele frequency data from +[allelefrequencies.net](www.allelefrequencies.net). Allele +frequencies from different populations can be combined to +estimate HLA frequencies of countries or other regions such as +global HLA frequencies. + +*** Important: *** The allele frequencies are not guaranteed to sum to 1. + +##### -------------------------------------- ########### + Modified by Sanjay SG Krishna, + Purcell/Li Lab, BDI, + Monash University, 2025. +###### -------------------------------------- ########### +More information on the methods used to combine allele frequencies +can be found in the https://github.com/BarinthusBio/HLAfreq repository. +This is modified from the original code to work to integrate with the MHC-TP project. +Thanks to the original authors for their work on this code. 
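+
+Illustrative usage sketch (argument values are examples only; real searches may
+need the completeness checks described below):
+
+    url = makeURL(country="Australia", locus="A")  # build a search URL for one locus
+    aftab = getAFdata(url)                         # download and parse every results page
+    caf = combineAF(aftab)                         # combine the studies into one estimate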
+""" + +from collections.abc import Iterable +from bs4 import BeautifulSoup +import requests +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import math +import scipy as sp +import matplotlib.colors as mcolors +from typing import List, Dict, Optional, Union +from urllib.parse import quote_plus +from cli.logger import CONSOLE + + + + +def simulate_population(alleles: Iterable[str], locus: str, population: str): + pop_size = np.random.randint(len(alleles), 50) + samples = np.random.choice(alleles, pop_size, replace=True) + counts = pd.Series(samples).value_counts() + counts.values / pop_size + pop = pd.DataFrame( + { + "allele": counts.index, + "loci": locus, + "population": population, + "allele_freq": counts.values / pop_size, + "sample_size": pop_size, + } + ) + return pop + + +def simulate_study(alleles, populations, locus): + study = [] + for i in range(populations): + pop = simulate_population(alleles=alleles, locus=locus, population=f"pop_{i}") + study.append(pop) + + study = pd.concat(study) + return study + +def get_list(search_by: str = "country", timeout: int = 20, format_output: str = "list") -> Union[List[str], pd.DataFrame, Dict[str, List[str]]]: + """ + Get a list of countries, ethnicities, or regions from the allelefrequencies.net website. + + All countries and ethnicities are listed at: + http://www.allelefrequencies.net/hla6006a.asp + + Args: + search_by (str): Type of data to retrieve. Options: + - "country": Get list of available countries + - "ethnic": Get list of available ethnicities + - "region": Get list of available regions + - "all": Get all available options + timeout (int): Request timeout in seconds. Defaults to 20. + format_output (str): Output format. Options: + - "list": Return as simple list(s) + - "dataframe": Return as pandas DataFrame + - "dict": Return as dictionary + + Returns: + Union[List[str], pd.DataFrame, Dict[str, List[str]]]: + Available options based on search_by parameter and format_output + + Raises: + ValueError: If search_by parameter is invalid + requests.exceptions.RequestException: If request fails + Exception: If parsing fails + + Example: + >>> countries = get_list("country") + >>> ethnicities = get_list("ethnic", format_output="dataframe") + >>> all_data = get_list("all", format_output="dict") + """ + + # Validate search_by parameter + valid_options = ["country", "ethnic", "region", "all"] + if search_by not in valid_options: + raise ValueError(f"search_by must be one of {valid_options}, got '{search_by}'") + + # Validate format_output parameter + valid_formats = ["list", "dataframe", "dict"] + if format_output not in valid_formats: + raise ValueError(f"format_output must be one of {valid_formats}, got '{format_output}'") + + base_url = "http://www.allelefrequencies.net/hla6006a.asp" + + try: + CONSOLE.log("Fetching data from allelefrequencies.net...") + response = requests.get(base_url, timeout=timeout) + response.raise_for_status() + + # Parse HTML content + soup = BeautifulSoup(response.text, 'html.parser') + + # Dictionary to store all options + all_options = {} + + # Find the form containing the select options + form = soup.find('form', {'name': 'form1'}) + if not form: + raise Exception("Could not find the main form on the webpage") + + # Extract countries + if search_by in ["country", "all"]: + country_select = form.find('select', {'name': 'hla_country'}) + if country_select: + countries = [] + for option in country_select.find_all('option'): + value = option.get('value', '').strip() + if value and 
value != '': # Skip empty options + countries.append(value) + all_options['country'] = sorted(countries) + CONSOLE.log(f"Found {len(countries)} countries") + + # Extract ethnicities + if search_by in ["ethnic", "all"]: + ethnic_select = form.find('select', {'name': 'hla_ethnic'}) + if ethnic_select: + ethnicities = [] + for option in ethnic_select.find_all('option'): + value = option.get('value', '').strip() + if value and value != '': # Skip empty options + ethnicities.append(value) + all_options['ethnic'] = sorted(ethnicities) + CONSOLE.log(f"Found {len(ethnicities)} ethnicities") + + # Extract regions + if search_by in ["region", "all"]: + region_select = form.find('select', {'name': 'hla_region'}) + if region_select: + regions = [] + for option in region_select.find_all('option'): + value = option.get('value', '').strip() + if value and value != '': # Skip empty options + regions.append(value) + all_options['region'] = sorted(regions) + CONSOLE.log(f"Found {len(regions)} regions") + + # Format and return results based on search_by parameter + if search_by == "all": + result_data = all_options + else: + if search_by not in all_options: + raise Exception(f"Could not extract {search_by} data from the webpage") + result_data = all_options[search_by] + + # Format output according to format_output parameter + if format_output == "list": + if search_by == "all": + return result_data # Return dictionary for "all" + else: + return result_data # Return list for specific type + + elif format_output == "dataframe": + if search_by == "all": + # Create a DataFrame with all data + max_len = max(len(v) for v in result_data.values()) + df_data = {} + for key, values in result_data.items(): + # Pad shorter lists with None + padded_values = values + [None] * (max_len - len(values)) + df_data[key] = padded_values + return pd.DataFrame(df_data) + else: + return pd.DataFrame({search_by: result_data}) + + elif format_output == "dict": + if search_by == "all": + return result_data + else: + return {search_by: result_data} + + except requests.exceptions.Timeout: + raise requests.exceptions.RequestException( + f"Request timeout after {timeout} seconds. Try increasing the timeout value." + ) + except requests.exceptions.RequestException as e: + raise requests.exceptions.RequestException(f"Failed to fetch data: {str(e)}") + except Exception as e: + raise Exception(f"Failed to parse webpage: {str(e)}") + +def url_encode_name(name: str) -> str: + """ + URL encode country/ethnic/region names for use in allelefrequencies.net URLs. + + Converts spaces to '+' and special characters to URL encoding: + - '(' becomes '%28' + - ')' becomes '%29' + - Other special characters are also encoded + + Args: + name (str): The original name + + Returns: + str: URL-encoded name ready for use in API calls + + Example: + >>> url_encode_name("Congo (Kinshasa)") + 'Congo+%28Kinshasa%29' + """ + # Use quote_plus which converts spaces to '+' and other chars to %XX + return quote_plus(name) + +def makeURL( + country="", + standard="s", + locus="", + resolution_pattern="bigger_equal_than", + resolution=2, + region="", + ethnic="", + study_type="", + dataset_source="", + sample_year="", + sample_year_pattern="", + sample_size="", + sample_size_pattern="", +): + """Create URL for search of allele frequency net database. + + All arguments are documented [here](http://www.allelefrequencies.net/extaccess.asp) + + Args: + country (str, optional): Country name to retrieve records from. Defaults to "". 
+ standard (str, optional): Filter study quality standard to this or higher. + {'g', 's', 'a'} Gold, silver, all. Defaults to 's'. + locus (str, optional): The locus to return allele data for. Defaults to "". + resolution_pattern (str, optional): Resolution comparitor {'equal', 'different', + 'less_than', 'bigger_than', 'less_equal_than', 'bigger_equal_than'}. + Filter created using `resolution` and `resolution_pattern`. + Defaults to "bigger_equal_than". + resolution (int, optional): Number of fields of resolution of allele. Filter + created using `resolution` and `resolution_pattern`. Defaults to 2. + region (str, optional): Filter to geographic region. {Asia, Australia, + Eastern Europe, ...}. + All regions listed [here](http://www.allelefrequencies.net/pop6003a.asp). + Defaults to "". + ethnic (str, optional): Filter to ethnicity. {"Amerindian", "Black", "Caucasian", ...}. + All ethnicities listed [here](http://www.allelefrequencies.net/pop6003a.asp). + Defaults to "". + study_type (str, optional): Type of study. {"Anthropology", "Blood+Donor", + "Bone+Marrow+Registry", "Controls+for+Disease+Study", "Disease+Study+Patients", + "Other", "Solid+Organd+Unrelated+Donors", "Stem+cell+donors"}. Defaults to "". + dataset_source (str, optional): Source of data. {"Literature", + "Proceedings+of+IHWs", "Unpublished"}. Defaults to "". + sample_year (int, optional): Sample year to compare to. Filter created using + sample_year and sample_year_pattern. Defaults to "". + sample_year_pattern (str, optional): Pattern to compare sample year to. Filter + created using sample_year and sample_year_pattern. {'equal', 'different', + 'less_than', 'bigger_than', 'less_equal_than', 'bigger_equal_than'}. + Defaults to "". + sample_size (int, optional): Sample size to compare to. Filter created using + sample_size and sample_size_pattern. Defaults to "". + sample_size_pattern (str, optional): Pattern to compare sample size to. Filter + created using sample_size and sample_size_pattern. {'equal', 'different', + 'less_than', 'bigger_than', 'less_equal_than', 'bigger_equal_than'}. + Defaults to "". + + Returns: + str: URL to search allelefrequencies.net + """ + base = "http://www.allelefrequencies.net/hla6006a.asp?" 
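+    # Each assignment below renders one query parameter ("key=value&"), and all of them
+    # are concatenated onto `base`. For example (illustrative values), locus="A" and
+    # country="Australia" give a URL containing "...hla_locus=A&hla_country=Australia&...".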
+ locus_type = "hla_locus_type=Classical&" + hla_locus = "hla_locus=%s&" % (locus) + country = "hla_country=%s&" % (country) + region = "hla_region=%s&" % (region) + ethnic = "hla_ethnic=%s&" % (ethnic) + study_type = "hla_study=%s&" % (study_type) + dataset_source = "hla_dataset_source=%s&" % (dataset_source) + sample_year = "hla_sample_year=%s&" % (sample_year) + sample_year_pattern = "hla_sample_year_pattern=%s&" % (sample_year_pattern) + sample_size = "hla_sample_size=%s&" % (sample_size) + sample_size_pattern = "hla_sample_size_pattern=%s&" % (sample_size_pattern) + hla_level_pattern = "hla_level_pattern=%s&" % (resolution_pattern) + hla_level = "hla_level=%s&" % (resolution) + standard = "standard=%s&" % standard + url = ( + base + + locus_type + + hla_locus + + country + + hla_level_pattern + + hla_level + + standard + + region + + ethnic + + study_type + + dataset_source + + sample_year + + sample_year_pattern + + sample_size + + sample_size_pattern + ) + return url + + +def parseAF(bs): + """Generate a dataframe from a given html page + + Args: + bs (bs4.BeautifulSoup): BeautifulSoup object from allelefrequencies.net page + + Returns: + pd.DataFrame: Table of allele, allele frequency, samplesize, and population + """ + # Get the results table from the div `divGenDetail` + tab = bs.find("div", {"id": "divGenDetail"}).find("table", {"class": "tblNormal"}) + # Get the column headers from the first row of the table + columns = [ + "line", + "allele", + "flag", + "population", + "carriers%", + "allele_freq", + "AF_graphic", + "sample_size", + "database", + "distribution", + "haplotype_association", + "notes", + ] + rows = [] + for row in tab.find_all("tr"): + rows.append([td.get_text(strip=True) for td in row.find_all("td")]) + # Make dataframe of table rows + # skip the first row as it's `th` headers + df = pd.DataFrame(rows[1:], columns=columns) + + # Get HLA loci + df["loci"] = df.allele.apply(lambda x: x.split("*")[0]) + + # Drop unwanted columns + df = df[["allele", "loci", "population", "allele_freq", "carriers%", "sample_size"]] + return df + + +def Npages(bs): + """How many pages of results are there? + + Args: + bs (bs4.BeautifulSoup): BS object of allelefrequencies.net results page + + Returns: + int: Total number of results pages + """ + # Get the table with number of pages + navtab = bs.find("div", {"id": "divGenNavig"}).find("table", {"class": "table10"}) + if not navtab: + raise AssertionError( + "navtab does not evaluate to True. Check URL returns results in web browser." + ) + # Get cell with ' of ' in + pagesOfN = [ + td.get_text(strip=True) for td in navtab.find_all("td") if " of " in td.text + ] + # Check single cell returned + if not len(pagesOfN) == 1: + raise AssertionError("divGenNavig should contain 1 of not %s" % len(pagesOfN)) + # Get total number of pages + N = pagesOfN[0].split("of ")[1] + N = int(N) + return N + + +def formatAF(AFtab, ignoreG=True): + """Format allele frequency table. + + Convert sample_size and allele_freq to numeric data type. + Removes commas from sample size. Removes "(*)" from allele frequency if + `ignoreG` is `True`. `formatAF()` is used internally by combineAF and getAFdata + by default. + + Args: + AFtab (pd.DataFrame): Allele frequency data downloaded from allelefrequency.net + using `getAFdata()`. + ignoreG (bool, optional): Treat G group alleles as normal. + See http://hla.alleles.org/alleles/g_groups.html for details. Defaults to True. + + Returns: + pd.DataFrame: The formatted allele frequency data. 
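+
+    Example (illustrative values): a sample_size of "1,234" becomes 1234, and an
+    allele_freq of "0.0850(*)" becomes 0.085 when `ignoreG` is True.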
+ """ + df = AFtab.copy() + if df.sample_size.dtype == "O": + df.sample_size = pd.to_numeric(df.sample_size.str.replace(",", "")) + if df.allele_freq.dtype == "O": + if ignoreG: + df.allele_freq = df.allele_freq.str.replace("(*)", "", regex=False) + df.allele_freq = pd.to_numeric(df.allele_freq) + return df + + +def getAFdata(base_url, timeout=20, format=True, ignoreG=True): + """Get all allele frequency data from a search base_url. + + Iterates over all pages regardless of which page is based. + + Args: + base_url (str): URL for base search. + timeout (int): How long to wait to receive a response. + format (bool): Format the downloaded data using `formatAF()`. + ignoreG (bool): treat allele G groups as normal. + See http://hla.alleles.org/alleles/g_groups.html for details. Default = True + + Returns: + pd.DataFrame: allele frequency data parsed into a pandas dataframe + """ + # Get BS object from base search + try: + bs = BeautifulSoup(requests.get(base_url, timeout=timeout).text, "html.parser") + except requests.exceptions.ReadTimeout as e: + raise Exception( + "Requests timeout, try a larger `timeout` value for `getAFdata()`" + ) from None + # How many pages of results + N = Npages(bs) + CONSOLE.log("%s pages of results" % N) + # iterate over pages, parse and combine data from each + tabs = [] + for i in range(N): + # print (" Parsing page %s" %(i+1)) + CONSOLE.log("Parsing page %s" % (i + 1), end="\r") + url = base_url + "page=" + str(i + 1) + try: + bs = BeautifulSoup(requests.get(url, timeout=timeout).text, "html.parser") + except requests.exceptions.ReadTimeout as e: + raise Exception( + "Requests timeout, try a larger `timeout` value for `getAFdata()`" + ) from None + tab = parseAF(bs) + tabs.append(tab) + CONSOLE.log("Download complete") + tabs = pd.concat(tabs) + if format: + try: + tabs = formatAF(tabs, ignoreG) + except AttributeError: + CONSOLE.log("Formatting failed, non-numeric datatypes may remain.") + return tabs + + +def incomplete_studies(AFtab, llimit=0.95, ulimit=1.1, datasetID="population"): + """Report any studies with allele freqs that don't sum to 1 + + Args: + AFtab (pd.DataFrame): Dataframe containing multiple studies + llimit (float, optional): Lower allele_freq sum limit that counts as complete. + Defaults to 0.95. + ulimit (float, optional): Upper allele_freq sum limit that will not be reported. + Defaults to 1.1. + datasetID (str): Unique identifier column for study + """ + poplocs = AFtab.groupby([datasetID, "loci"]).allele_freq.sum() + lmask = poplocs < llimit + if sum(lmask > 0): + CONSOLE.log(poplocs[lmask]) + CONSOLE.log(f"{sum(lmask)} studies have total allele frequency < {llimit}") + umask = poplocs > ulimit + if sum(umask > 0): + CONSOLE.log(poplocs[umask]) + CONSOLE.log(f"{sum(umask)} studies have total allele frequency > {ulimit}") + incomplete = pd.concat([poplocs[lmask], poplocs[umask]]) + return incomplete + + +def only_complete(AFtab, llimit=0.95, ulimit=1.1, datasetID="population"): + """Returns only complete studies. + + Data is dropped if the locus for that population is not complete, i.e. doesn't + sum to between `llimit` and `ulimit`. This prevents throwing away data if + another loci in the population is incomplete. + + Args: + AFtab (pd.DataFrame): Dataframe containing multiple studies + llimit (float, optional): Lower allele_freq sum limit that counts as complete. + Defaults to 0.95. + ulimit (float, optional): Upper allele_freq sum limit that will not be reported. + Defaults to 1.1. + datasetID (str): Unique identifier column for study. 
Defaults to 'population'. + + Returns: + pd.DataFrame: Allele frequency data of multiple studies, but only complete studies. + """ + noncomplete = incomplete_studies( + AFtab=AFtab, llimit=llimit, ulimit=ulimit, datasetID=datasetID + ) + # Returns False if population AND loci are in the noncomplete.index + # AS A PAIR + # This is important so that we don't throw away all data on a population + # just because one loci is incomplete. + complete_mask = AFtab.apply( + lambda x: (x[datasetID], x.loci) not in noncomplete.index, axis=1 + ) + df = AFtab[complete_mask] + return df + + +def check_resolution(AFtab): + """Check if all alleles in AFtab have the same resolution. + Will print the number of records with each resolution. + + Args: + AFtab (pd.DataFrame): Allele frequency data + + Returns: + bool: True only if all alleles have the same resolution, else False. + """ + resolution = 1 + AFtab.allele.str.count(":") + resVC = resolution.value_counts() + pass_check = len(resVC) == 1 + if not pass_check: + CONSOLE.log(resVC) + CONSOLE.log("Multiple resolutions in AFtab. Fix with decrease_resolution()") + return pass_check + + +def decrease_resolution(AFtab, newres, datasetID="population"): + """Decrease allele resolution so all alleles have the same resolution. + + Args: + AFtab (pd.DataFrame): Allele frequency data. + newres (int): The desired number of fields for resolution. + datasetID (str, optional): Column to use as stud identifier. + Defaults to 'population'. + + Returns: + pd.DataFrame: Allele frequency data with all alleles of requested resolution. + """ + df = AFtab.copy() + resolution = 1 + df.allele.str.count(":") + if not all(resolution >= newres): + raise AssertionError(f"Some alleles have resolution below {newres} fields") + new_allele = df.allele.str.split(":").apply(lambda x: ":".join(x[:newres])) + df.allele = new_allele + collapsed = collapse_reduced_alleles(df, datasetID=datasetID) + return collapsed + + +def collapse_reduced_alleles(AFtab, datasetID="population"): + df = AFtab.copy() + # Group by alleles within datasets + grouped = df.groupby([datasetID, "allele"]) + # Sum allele freq but keep other columns + collapsed = grouped.apply( + lambda row: [ + sum(row.allele_freq), + row.sample_size.unique()[0], + row.loci.unique()[0], + len(row.loci.unique()), + len(row.sample_size.unique()), + ] + ) + collapsed = pd.DataFrame( + collapsed.tolist(), + index=collapsed.index, + columns=["allele_freq", "sample_size", "loci", "#loci", "#sample_sizes"], + ).reset_index() + # Within a study each all identical alleles should have the same loci and sample size + if not all(collapsed["#loci"] == 1): + raise AssertionError( + "Multiple loci found for a single allele in a single population" + ) + if not all(collapsed["#sample_sizes"] == 1): + raise AssertionError( + "Multiple sample_sizes found for a single allele in a single population" + ) + collapsed = collapsed[ + ["allele", "loci", "population", "allele_freq", "sample_size"] + ] + alleles_unique_in_study(collapsed) + return collapsed + + +def unmeasured_alleles(AFtab, datasetID="population"): + """When combining AF estimates, unreported alleles can inflate frequencies + so AF sums to >1. Therefore we add unreported alleles with frequency zero. 
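+
+    Illustrative example: if one study at locus A reports only A*02:01 while another
+    reports both A*02:01 and A*01:01, then A*01:01 is appended to the first study with
+    allele_freq 0 and that study's own sample_size.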
+ + Args: + AFtab (pd.DataFrame): Formatted allele frequency data + datasetID (str): Unique identifier column for study + + Returns: + pd.DataFrame: Allele frequency data with all locus alleles reported + for each dataset + """ + df = AFtab.copy() + loci = df.loci.unique() + # Iterate over loci separately + for locus in loci: + # Iterate over each dataset reporting that locus + datasets = df[df.loci == locus][datasetID].unique() + for dataset in datasets: + # Single locus, single dataset + datasetAF = df[(df[datasetID] == dataset) & (df.loci == locus)] + # What was the sample size for this data? + dataset_sample_size = datasetAF.sample_size.unique() + if not (len(dataset_sample_size) == 1): + raise AssertionError( + "dataset_sample_size must be 1, not %s" % len(dataset_sample_size) + ) + dataset_sample_size = dataset_sample_size[0] + # Get all alleles for this locus (across datasets) + ualleles = df[df.loci == locus].allele.unique() + # Which of these alleles are not in this dataset? + missing_alleles = [ + allele for allele in ualleles if not allele in datasetAF.allele.values + ] + missing_rows = [ + (al, locus, dataset, 0, 0, dataset_sample_size) + for al in missing_alleles + ] + missing_rows = pd.DataFrame( + missing_rows, + columns=[ + "allele", + "loci", + datasetID, + "allele_freq", + "carriers%", + "sample_size", + ], + ) + # Add them in with zero frequency + if not missing_rows.empty: + df = pd.concat([df, missing_rows], ignore_index=True) + return df + + +def combineAF( + AFtab, + weights="2n", + alpha=[], + datasetID="population", + format=True, + ignoreG=True, + add_unmeasured=True, + complete=True, + resolution=True, + unique=True, +): + """Combine allele frequencies from multiple studies. + + `datasetID` is the unique identifier for studies to combine. + Allele frequencies combined using a Dirichlet distribution where each study's + contribution to the concentration parameter is $2 * sample_size * allele_frequency$. + Sample size is doubled to get `2n` due to diploidy. If an alternative `weights` is + set it is not doubled. The total concentration parameter of the Dirichlet distribution + is the contributions from all studies plus the prior `alpha`. If `alpha` is not set + the prior defaults to 1 observation of each allele. + + Args: + AFtab (pd.DataFrame): Table of Allele frequency data + weights (str, optional): Column to be weighted by allele frequency to generate + concentration parameter of Dirichlet distribution. Defaults to '2n'. + alpha (list, optional): Prior to use for Dirichlet distribution. Defaults to []. + datasetID (str, optional): Unique identifier column for study. Defaults to + 'population'. + format (bool, optional): Run `formatAF()`. Defaults to True. + ignoreG (bool, optional): Treat allele G groups as normal, see `formatAF()`. + Defaults to True. + add_unmeasured (bool, optional): Add unmeasured alleles to each study. This is + important to ensure combined allele frequencies sum to 1. See + `add_unmeasured()`. Defaults to True. + complete (bool, optional): Check study completeness. Uses default values for + `incomplete_studies()`. If you are happy with your study completeness can + be switched off with False. Defaults to True. + resolution (bool, optional): Check that all alleles have the same resolution, + see `check_resolution()`. Defaults to True. + unique (bool, optional): Check that each allele appears no more than once per + study. See `alleles_unique_in_study()`. Defaults to True. 
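+
+    Worked example (illustrative numbers): two studies of locus A with sample sizes
+    100 and 50 that report A*02:01 at 0.30 and 0.40 contribute
+    c = 2*100*0.30 + 2*50*0.40 = 100 observations of that allele; with the default
+    prior of one observation per allele, the combined allele_freq is the Dirichlet
+    mean (alpha + c) / sum(alpha + c) taken over all alleles at the locus.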
+ + Returns: + pd.DataFrame: Allele frequencies after combining estimates from all studies. + *allele_freq* is the combined frequency estimate from the Dirichlet mean + where the concentration is `alpha` + `c`. + *alpha* is the prior used for the Dirichlet distribution. + *c* is the observations used for the Dirichlet distribution. + *sample_size* is the total sample size of all combined studies. + *wav* is the weighted average. + """ + df = AFtab.copy() + single_loci(df) + if unique: + if not alleles_unique_in_study(df, datasetID=datasetID): + raise AssertionError("The same allele appears multiple times in a dataset") + if complete: + if not incomplete_studies(df, datasetID=datasetID).empty: + raise AssertionError( + "AFtab contains studies with AF that doesn't sum to 1. Check incomplete_studies(AFtab)" + ) + if resolution: + if not check_resolution(df): + raise AssertionError( + "AFtab conains alleles at multiple resolutions, check check_resolution(AFtab)" + ) + if format: + df = formatAF(df, ignoreG) + if add_unmeasured: + df = unmeasured_alleles(df, datasetID) + try: + df["2n"] = df.sample_size * 2 + except AttributeError: + CONSOLE.log("column '2n' could not be created") + df["c"] = df.allele_freq * df[weights] + grouped = df.groupby("allele", sort=True) + combined = grouped.apply( + lambda row: [ + row.name, + row.loci.unique()[0], + np.average(row.allele_freq, weights=row[weights]), + row.c.sum(), + row.sample_size.sum(), + ] + ) + combined = pd.DataFrame( + combined.tolist(), columns=["allele", "loci", "wav", "c", "sample_size"] + ) + combined = combined.reset_index(drop=True) + # Check that all alleles in a locus have the same sample size + # after merging + if duplicated_sample_size(combined): + id_duplicated_allele(grouped) + if not alpha: + alpha = default_prior(len(combined.allele)) + combined["alpha"] = alpha + # Calculate Dirichlet mean for each allele + combined["allele_freq"] = sp.stats.dirichlet(combined.alpha + combined.c).mean() + + return combined + + +def default_prior(k): + """Calculate a default prior, 1 observation of each class. + + Args: + k (int): Number of classes in the Dirichlet distribution. + + Returns: + list: List of k 1s to use as prior. + """ + alpha = [1] * k + return alpha + + +def single_loci(AFtab): + """Check that allele frequency data is only of one locus + + Args: + AFtab (pd.DataFrame): Allele frequency data + """ + if not len(AFtab.loci.unique()) == 1: + raise AssertionError("'AFtab' must contain only 1 loci") + + +def alleles_unique_in_study(AFtab, datasetID="population"): + """Are all alleles unique in each study? + + Checks that no alleles are reported more than once in a single study. + Study is defined by `datasetID`. + + Args: + AFtab (pd.DataFrame): Allele frequency data + datasetID (str, optional): Unique identifier column to define study. + Defaults to 'population'. + + Returns: + bool: `True` on if no alleles occur more than once in any study, otherwise `False`. + """ + df = AFtab.copy() + grouped = df.groupby([datasetID, "allele"]) + # Are allele alleles unique? i.e. do any occur multiple times in grouping? + unique = grouped.size()[grouped.size() > 1].empty + if not unique: + CONSOLE.log(f"Non unique alleles in study, is datasetID correct? 
{datasetID}") + CONSOLE.log(grouped.size()[grouped.size() > 1]) + return unique + + +def duplicated_sample_size(AFtab): + """Returns True if any loci has more than 1 unique sample size""" + locus_sample_sizes = AFtab.groupby("loci").sample_size.apply( + lambda x: len(x.unique()) + ) + return any(locus_sample_sizes != 1) + + +def id_duplicated_allele(grouped): + """Reports the allele that has mupltiple sample sizes""" + duplicated_population = grouped.population.apply(lambda x: any(x.duplicated())) + if not all(~duplicated_population): + raise AssertionError( + f"duplicated population within allele {duplicated_population[duplicated_population].index.tolist()}" + ) + + +def population_coverage(p): + """Proportion of people with at least 1 copy of this allele assuming HWE. + + Args: + p (float): Allele frequency + + Returns: + float: Sum of homozygotes and heterozygotes for this allele + """ + q = 1 - p + homo = p**2 + hetero = 2 * p * q + return homo + hetero + + +def betaAB(alpha): + """Calculate `a` `b` values for all composite beta distributions. + + Given the `alpha` vector defining a Dirichlet distribution calculate the `a` `b` values + for all composite beta distributions. + + Args: + alpha (list): Values defining a Dirichlet distribution. This will be the prior + (for a naive distribution) or the prior + caf.c for a posterior distribution. + + Returns: + list: List of `a` `b` values defining beta values, i.e. for each allele it is + the number of times it was and wasn't observed. + """ + ab = [(a, sum(alpha) - a) for a in alpha] + return ab + + +# def betaCI(a,b,credible_interval=0.95): +# """Calculat the central credible interval of a beta distribution + +# Args: +# a (float): Beta shape parameter `a`, i.e. the number of times the allele was observed. +# b (float): Beta shape parameter `b`, i.e. the number of times the allele was not observed. +# credible_interval (float, optional): The size of the credible interval requested. Defaults to 0.95. + +# Returns: +# tuple: Lower and upper credible interval of beta distribution. +# """ +# bd = sp.stats.beta(a,b) +# lower_quantile = (1-credible_interval)/2 +# upper_quantile = 1-lower_quantile +# lower_interval = bd.ppf(lower_quantile) +# upper_interval = bd.ppf(upper_quantile) +# return lower_interval, upper_interval + +# def AFci(caf, credible_interval=0.95): +# """Calculate credible interval for combined allele frequency table. +# Note that this ignores sampling error so confidence interval is too tight. +# Use HLAhdi.AFhdi() instead. + +# Args: +# caf (pd.DataFrame): Table produced by combineAF() +# credible_interval (float, optional): The desired confidence interval. Defaults to 0.95. + +# Returns: +# list: Lower and upper credible intervals as a list of tuples +# """ +# ab = betaAB( +# caf.alpha + caf.c, +# ) +# ci = [betaCI(a, b, credible_interval) for a,b in ab] +# return ci + + +def plot_prior(concentration, ncol=2, psteps=1000, labels=""): + """Plot probability density function for prior values. + + Args: + concentration (list): Vector of the prior Dirichlet concentration values. + ncol (int, optional): Number of columns. Defaults to 2. + labels (list, optional): Labels for elements of concentration in the same + order. Defaults to "". 
+ """ + ab = betaAB(concentration) + pline = np.linspace(0, 1, psteps) + nrow = math.ceil(len(concentration) / ncol) + fig, ax = plt.subplots(nrow, ncol, sharex=True) + fig.suptitle("Probability density") + # If labels is a list nothing happens, + # But if it's a series it converts to a list + labels = list(labels) + if not labels: + labels = [""] * len(concentration) + if not len(concentration) == len(labels): + raise AssertionError("concentration must be same length as labels") + for i, alpha in enumerate(concentration): + a, b = ab[i] + bd = sp.stats.beta(a, b) + pdf = [bd.pdf(p) for p in pline] + ax.flat[i].plot(pline, pdf) + ax.flat[i].set_title(labels[i]) + for axi in ax[-1, :]: + axi.set(xlabel="Allele freq") + for axi in ax[:, 0]: + axi.set(ylabel="PDF") + plt.show() + + +def plotAF( + caf=pd.DataFrame(), + AFtab=pd.DataFrame(), + cols=list(mcolors.TABLEAU_COLORS.keys()), + datasetID="population", + hdi=pd.DataFrame(), + compound_mean=pd.DataFrame(), +): + """Plot allele frequency results from `HLAfreq`. + + Plot combined allele frequencies, individual allele frequencies, + and credible intervals on combined allele frequency estimates. + Credible interval is only plotted if a value is given for `hdi`. + The plotted Credible interval is whatever was passed to HLAfreq_pymc.AFhdi() + when calculating hdi. + + Args: + caf (pd.DataFrame, optional): Combined allele frequency estimates from + HLAfreq.combineAF. Defaults to pd.DataFrame(). + AFtab (pd.DataFrame, optional): Table of allele frequency data. Defaults + to pd.DataFrame(). + cols (list, optional): List of colours to use for each individual dataset. + Defaults to list(mcolors.TABLEAU_COLORS.keys()). + datasetID (str, optional): Column used to define separate datasets. Defaults + to "population". + weights (str, optional): Column to be weighted by allele frequency to generate + concentration parameter of Dirichlet distribution. Defaults to '2n'. + hdi (pd.DataFrame, optional): The high density interval object to plot credible + intervals. Produced by HLAfreq.HLA_pymc.AFhdi(). Defaults to pd.DataFrame(). + compound_mean (pd.DataFrame, optional): The high density interval object to plot + post_mean. Produced by HLAfreq.HLA_pymc.AFhdi(). Defaults to pd.DataFrame(). 
+ """ + # Plot allele frequency for each dataset + if not AFtab.empty: + # Cols must be longer than the list of alleles + # If not, repeat the list of cols + repeat_cols = np.ceil(len(AFtab[datasetID]) / len(cols)) + repeat_cols = int(repeat_cols) + cols = cols * repeat_cols + # Make a dictionary mapping datasetID to colours + cmap = dict(zip(AFtab[datasetID].unique(), cols)) + plt.scatter( + x=AFtab.allele_freq, + y=AFtab.allele, + c=[cmap[x] for x in AFtab[datasetID]], + alpha=0.7, + zorder=2, + ) + # Plot combined allele frequency + if not caf.empty: + plt.scatter( + x=caf.allele_freq, + y=caf.allele, + edgecolors="black", + facecolors="none", + zorder=3, + ) + # Plot high density interval + if not hdi.empty: + # assert not AFtab.empty, "AFtab is needed to calculate credible interval" + # from HLAfreq import HLAfreq_pymc as HLAhdi + # print("Fitting model with PyMC, make take a few seconds") + # hdi = HLAhdi.AFhdi( + # AFtab=AFtab, + # weights=weights, + # datasetID=datasetID, + # credible_interval=credible_interval, + # conc_mu=conc_mu, + # conc_sigma=conc_sigma + # ) + for interval in hdi.iterrows(): + # .iterrows returns a index and data as a tuple for each row + plt.hlines( + y=interval[1]["allele"], + xmin=interval[1]["lo"], + xmax=interval[1]["hi"], + color="black", + ) + if not compound_mean.empty: + for row in compound_mean.iterrows(): + plt.scatter( + y=row[1]["allele"], x=row[1]["post_mean"], color="black", marker="|" + ) + plt.xlabel("Allele frequency") + plt.grid(zorder=0) + plt.show() \ No newline at end of file diff --git a/cli/NumbaSearch.py b/cli/NumbaSearch.py new file mode 100644 index 0000000..252a308 --- /dev/null +++ b/cli/NumbaSearch.py @@ -0,0 +1,407 @@ +import numpy as np +import pandas as pd +from numba import jit, prange, types +from numba.typed import Dict, List +import os +import pickle +import mmap +from pathlib import Path +import time +from typing import Tuple, Optional, Dict as PyDict, List as PyList +import warnings +warnings.filterwarnings('ignore') +from cli.logger import * + +class npClusterSearch: + """optimized cluster search using Numba JIT compilation""" + + def __init__(self, cache_dir: str = "numba_cache"): + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(exist_ok=True) + self.reference_matrices = None + self.reference_metadata = None + self.is_initialized = False + self.console = CONSOLE + + def build_reference_cache(self, db: pd.DataFrame, species: str, + data_dir: str, force_rebuild: bool = False): + """ + Build fast reference cache using memory-mapped arrays + + Args: + db: Database with matrix paths + species: Species name + data_dir: Directory containing matrix files + force_rebuild: Force cache rebuild + """ + cache_file = self.cache_dir / f"ref_matrices_{species}.npy" + metadata_file = self.cache_dir / f"ref_metadata_{species}.pkl" + + if not force_rebuild and cache_file.exists() and metadata_file.exists(): + self.console.log("Loading existing fast NP array cache...") + self._load_cache(cache_file, metadata_file) + return + + self.console.log("Building fast reference cache...") + start_time = time.time() + + # Standard amino acid order + amino_acids = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', + 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'] + + self.amino_acids = amino_acids + # Load all matrices + matrices_data = [] + metadata_list = [] + + matrix_paths = db['matrices_path'].unique() + + for i, mat_path in enumerate(matrix_paths): + if i % 1000 == 0: + self.console.log(f"Processed {i}/{len(matrix_paths)} 
matrices...") + + full_path = os.path.join(data_dir, mat_path) + matrix = self._fast_load_matrix(full_path, amino_acids) + + if matrix is not None: + matrices_data.append(matrix) + + # Extract HLA info + if species.lower() == 'human': + hla = os.path.basename(mat_path).replace('.txt', '').split('_')[1] if '_' in mat_path else os.path.basename(mat_path).replace('.txt', '') + else: + hla = os.path.basename(mat_path).replace('.txt', '') + + metadata_list.append({ + 'index': len(matrices_data) - 1, + 'path': mat_path, + 'hla': hla, + 'original_db_index': i + }) + + + # Find max dimensions and create unified array + if not matrices_data: + raise ValueError("No valid matrices found!") + # self.console.log(f"No valid matrices found!") + + + + max_positions = max(m.shape[0] for m in matrices_data) + n_matrices = len(matrices_data) + n_amino_acids = len(amino_acids) + + # Create memory-efficient unified array + unified_matrices = np.zeros((n_matrices, max_positions, n_amino_acids), dtype=np.float32) + + for i, matrix in enumerate(matrices_data): + # Zero-pad shorter matrices + unified_matrices[i, :matrix.shape[0], :matrix.shape[1]] = matrix + + # Save with memory mapping for fast loading + np.save(cache_file, unified_matrices) + + # Save metadata + metadata_df = pd.DataFrame(metadata_list) + with open(metadata_file, 'wb') as f: + pickle.dump({ + 'metadata': metadata_df, + 'amino_acids': amino_acids, + 'max_positions': max_positions, + 'species': species + }, f) + + self.reference_matrices = unified_matrices + self.reference_metadata = metadata_df + self.is_initialized = True + + build_time = time.time() - start_time + self.console.log(f"Cache built in {build_time:.2f}s with {n_matrices} matrices ({max_positions} positions)") + + def _load_cache(self, cache_file: Path, metadata_file: Path): + """Load pre-built cache with memory mapping for instant access""" + # Memory-mapped loading for zero-copy access + self.reference_matrices = np.load(cache_file, mmap_mode='r') + + with open(metadata_file, 'rb') as f: + cache_data = pickle.load(f) + self.reference_metadata = cache_data['metadata'] + self.amino_acids = cache_data['amino_acids'] + self.max_positions = cache_data['max_positions'] + + self.is_initialized = True + self.console.log(f"fast cache loaded: {len(self.reference_matrices)} matrices") + + def _fast_load_matrix(self, file_path: str, amino_acids: PyList[str]) -> Optional[np.ndarray]: + """Optimized matrix loading""" + try: + if not os.path.exists(file_path): + return None + + # Fast file reading with minimal parsing + with open(file_path, 'r') as f: + lines = f.readlines() + + data_rows = [] + n_aa = len(amino_acids) + + for line in lines: + line = line.strip() + if not line or line.startswith('#') or 'A R N D' in line: + continue + + parts = line.split() + if len(parts) >= n_aa: + try: + # Take last n_aa values (amino acid scores) + values = [float(parts[-(n_aa-i)]) for i in range(n_aa)] + data_rows.append(values) + except (ValueError, IndexError): + continue + + return np.array(data_rows, dtype=np.float32) if data_rows else None + + except Exception: + return None + + def Np_fast_search(self, gibbs_matrices_dir: str, n_clusters: str = "all", + hla_filter: PyList[str] = None, threshold: float = 0.70, + topHit: int = 3): + """ + Fast correlation search using Numba JIT compilation + + Returns: + -------- + results : dict + Best hit per Gibbs matrix (compatible with existing code) + top_hits_dict : dict + Top N hits per Gibbs matrix (hla -> correlation) + """ + if not self.is_initialized: + raise 
ValueError("Cache not initialized. Call build_reference_cache first.") + + self.console.log("Starting NP search...") + start_time = time.time() + + # Load Gibbs matrices + gibbs_files = [f for f in os.listdir(gibbs_matrices_dir) if f.endswith('.mat')] + if n_clusters.isdigit(): + gibbs_files = [f for f in gibbs_files if f.endswith(f"of{n_clusters}.mat")] + + self.console.log(f"Searching {len(gibbs_files)} Gibbs matrices against {len(self.reference_matrices)} references...") + + gibbs_matrices_list = [] + gibbs_names = [] + for gf in gibbs_files: + self.console.log(f"Loading {gf}...") + matrix = self._fast_load_matrix(os.path.join(gibbs_matrices_dir, gf), self.amino_acids) + if matrix is not None: + padded = np.zeros((self.max_positions, len(self.amino_acids)), dtype=np.float32) + padded[:matrix.shape[0], :matrix.shape[1]] = matrix + gibbs_matrices_list.append(padded) + gibbs_names.append(gf) + self.console.log(f"Loaded {len(gibbs_matrices_list)} Gibbs matrices.") + if not gibbs_matrices_list: + return {}, {} + + gibbs_matrices = np.array(gibbs_matrices_list, dtype=np.float32) + + # HLA filter mask + hla_mask = np.ones(len(self.reference_metadata), dtype=np.bool_) + if hla_filter: + hla_mask = self.reference_metadata['hla'].isin(hla_filter).values + + correlation_matrix, invalid_flags = compute_all_correlations_jit( + gibbs_matrices, + self.reference_matrices.astype(np.float32), + hla_mask, + threshold + ) + + for i, flag in enumerate(invalid_flags): + if flag == 1: + self.console.log(f"Gibbs matrix {i} was skipped (invalid or insufficient data).") + + results = {} + top_hits_dict = {} + + for i, gibbs_name in enumerate(gibbs_names): + corrs = correlation_matrix[i, :] + top_indices = np.argsort(corrs)[::-1] # descending order + top_indices = [idx for idx in top_indices if corrs[idx] >= threshold][:topHit] + + if not top_indices: + continue + + # Best hit for compatibility + best_idx = top_indices[0] + best_hit = { + 'hla': self.reference_metadata.iloc[best_idx]['hla'], + 'correlation': corrs[best_idx], + 'ref_path': self.reference_metadata.iloc[best_idx]['path'] + } + results[gibbs_name] = best_hit + + # Top N hits as dict + top_hits_dict[gibbs_name] = [ + { + 'hla': self.reference_metadata.iloc[idx]['hla'], + 'correlation': corrs[idx], + 'ref_path': self.reference_metadata.iloc[idx]['path'] + } + for idx in top_indices + ] + + search_time = time.time() - start_time + self.console.log(f"NP search completed in {search_time:.3f} seconds!") + self.console.log(f"Found {len(results)} matches above threshold {threshold}") + + return results, top_hits_dict + + + +# Numba JIT-compiled functions for maximum speed +@jit(nopython=True, parallel=True, fastmath=True) +def compute_all_correlations_jit(gibbs_matrices, ref_matrices, hla_mask, threshold): + """ + JIT-compiled correlation computation - THIS IS THE SPEED SECRET! Hope it works!! 
+ + Computes all correlations in parallel using optimized machine code + """ + n_gibbs = gibbs_matrices.shape[0] + n_refs = ref_matrices.shape[0] + + # Pre-allocate result matrix + correlations = np.full((n_gibbs, n_refs), -1.0, dtype=np.float32) + + #recored if any invalid matrix + invalid_flags = np.zeros(n_gibbs, dtype=np.int32) # 0 = valid, 1 = invalid + # Parallel computation across Gibbs matrices + for i in prange(n_gibbs): + gibbs_flat = gibbs_matrices[i].flatten() + + # Remove NaN and zero values once + gibbs_valid_mask = ~(np.isnan(gibbs_flat) | (gibbs_flat == 0.0)) + gibbs_clean = gibbs_flat[gibbs_valid_mask] + + if len(gibbs_clean) < 10: # Skip if too few valid values + invalid_flags[i] = 1 + continue + + # Precompute statistics for Gibbs matrix + gibbs_mean = np.mean(gibbs_clean) + gibbs_std = np.std(gibbs_clean) + + if gibbs_std == 0.0: + invalid_flags[i] = 1 + continue + + # Compute correlations with all reference matrices + for j in range(n_refs): + if not hla_mask[j]: # Skip filtered HLA types + continue + + ref_flat = ref_matrices[j].flatten() + ref_clean = ref_flat[gibbs_valid_mask] # Use same mask + + # Quick statistics + ref_mean = np.mean(ref_clean) + ref_std = np.std(ref_clean) + + if ref_std == 0.0: + continue + + # Fast Pearson correlation + numerator = np.mean((gibbs_clean - gibbs_mean) * (ref_clean - ref_mean)) + correlation = numerator / (gibbs_std * ref_std) + + # Only store if above threshold (saves memory) + if correlation >= threshold: + correlations[i, j] = correlation + + return correlations, invalid_flags + +@jit(nopython=True, fastmath=True) +def np_pearson_correlation(x, y): + """fast Pearson correlation for small arrays""" + n = len(x) + if n < 2: + return 0.0 + + sum_x = np.sum(x) + sum_y = np.sum(y) + sum_xx = np.sum(x * x) + sum_yy = np.sum(y * y) + sum_xy = np.sum(x * y) + + numerator = n * sum_xy - sum_x * sum_y + denominator = np.sqrt((n * sum_xx - sum_x * sum_x) * (n * sum_yy - sum_y * sum_y)) + + return numerator / denominator if denominator != 0.0 else 0.0 + +class NP_clusterSearchCLI: + """Wrapper for NP search with easy integration""" + + def __init__(self, cache_dir: str = "numba_cache"): + self.np_fast = npClusterSearch(cache_dir=cache_dir) + self.correlation_dict = {} + self.console = CONSOLE + + def compute_correlations_V3(self, db: pd.DataFrame, gibbs_results: str, + n_clusters: str, output_path: str, + hla_list: PyList[str] = None, + threshold: float = 0.70, + data_dir: str = None, + species: str = "human", + topHit: int = 3): + """fast correlation computation - 10-100x faster than original""" + + self.console.log("Starting Numba Arrary correlation computation...") + + # Build/load cache + cache_start = time.time() + self.np_fast.build_reference_cache(db, species, data_dir) + cache_time = time.time() - cache_start + self.console.log(f"Cache ready in {cache_time:.2f}s") + + #fast search + gibbs_matrices_dir = os.path.join(gibbs_results, "matrices") + results, top_hits_dict = self.np_fast.Np_fast_search( + gibbs_matrices_dir, n_clusters, hla_list, threshold, topHit + ) + + # Convert to your existing format + # for gibbs_name, result in results.items(): + # key = (gibbs_name, result['ref_path']) + # self.correlation_dict[key] = result['correlation'] + + for gibbs_name, hits in top_hits_dict.items(): # hits is a list + for result in hits: # each result is a dict + key = (gibbs_name, result['ref_path']) + self.correlation_dict[key] = result['correlation'] + + + # `top_hits_df` now has all top N correlations per Gibbs matrix + # 
print(top_hits_df.head()) + # self.console.log(f"Correlation computation complete! Found {self.correlation_dict} correlations") + self.console.log(f"NP search complete! Found {len(results)} correlations") + return results + + + +# if __name__ == "__main__": +# search = NP_clusterSearchCLI() +# species = "mouse"#"human" +# n_clusters = "all" +# # gibbs_results = '/home/sson0030/xy86_scratch2/SANJAY/MHC-TP/data/P6215/mhcI_1224927/' +# gibbs_results = '/home/sson0030/xy86_scratch2/SANJAY/MHC-TP/data/9mersonly' +# output_path = '/home/sson0030/xy86_scratch2/SANJAY/MHC-TP/data/NPoutput_directoryMHC' +# hla_list = None +# threshold = 0.70 +# data_dir = "/home/sson0030/xy86_scratch2/SANJAY/MHC-TP/data/ref_data" +# db_df = pd.read_csv(f'{data_dir}/{species}.db') # Load your database +# cluster_search = search.compute_correlations_V3( +# db_df, gibbs_results, n_clusters, output_path, +# hla_list, threshold, data_dir, species +# ) + \ No newline at end of file diff --git a/cli/__init__.py b/cli/__init__.py index c99cc6d..d91d4a1 100644 --- a/cli/__init__.py +++ b/cli/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.1.0-beta" +__version__ = "1.1.1-beta" from warnings import filterwarnings @@ -6,10 +6,4 @@ install(show_locals=True) # type: ignore -# mzmlb is not used, so hdf5plugin is not needed -filterwarnings( - "ignore", - message="hdf5plugin is missing", - category=UserWarning, - module="psims.mzmlb", -) +from .HLAfreq import * \ No newline at end of file diff --git a/cli/cluster_search.py b/cli/cluster_search.py index 51abff3..85ef2a3 100644 --- a/cli/cluster_search.py +++ b/cli/cluster_search.py @@ -21,7 +21,7 @@ from jinja2 import Template import sys import altair as alt - +from cli.NumbaSearch import NP_clusterSearchCLI # import HTML import shutil @@ -47,6 +47,9 @@ def __init__(self): self.d3js_json = None self.db = None # databse of motif and matrix self.treshold_img = False + self.gibbs_results = None + self.kld_df = None + self._outfolder=None def generate_unique_random_ids(self, count: int) -> list: """ @@ -96,7 +99,7 @@ def parse_gibbs_output(self, gibbs_path: str, n_clusters: int) -> pd.DataFrame: for c_file in os.listdir(res_path): if f'{n_clusters}g.ds' in c_file: file_path = os.path.join(res_path, c_file) - df = pd.read_csv(file_path, sep='\s+') + df = pd.read_csv(file_path, sep="\s+") # output_path = f'data/sampledata_701014/res_{n_clusters}g.csv' # df.to_csv(output_path, index=False) logging.info(f"Data parsed for No clusters {n_clusters}") @@ -157,6 +160,25 @@ def _check_HLA_DB(self, HLA_list: list, ref_folder: str) -> bool: return True @staticmethod + # def format_input_gibbs(gibbs_matrix: str) -> pd.DataFrame: + # """ + # Format the Gibbs output matrix. + + # :param gibbs_matrix: Path to the Gibbs matrix file + # :return: Processed DataFrame + # """ + # df = pd.read_csv(gibbs_matrix) + # amino_acids = df.iloc[0, 0].split() + # df.iloc[:, 0] = df.iloc[:, 0].str.replace( + # r"^\d+\s\w\s", "", regex=True) + # new_df = pd.DataFrame(df.iloc[1:, 0].str.split( + # expand=True).values, columns=amino_acids) + # new_df.reset_index(drop=True, inplace=True) + # new_df = new_df.apply(pd.to_numeric, errors='coerce') + # print(f"Formatted Gibbs matrix from {gibbs_matrix}") + # print(new_df) + # breakpoint() + # return new_df def format_input_gibbs(gibbs_matrix: str) -> pd.DataFrame: """ Format the Gibbs output matrix. 
@@ -164,15 +186,100 @@ def format_input_gibbs(gibbs_matrix: str) -> pd.DataFrame: :param gibbs_matrix: Path to the Gibbs matrix file :return: Processed DataFrame """ - df = pd.read_csv(gibbs_matrix) - amino_acids = df.iloc[0, 0].split() - df.iloc[:, 0] = df.iloc[:, 0].str.replace( - r"^\d+\s\w\s", "", regex=True) - new_df = pd.DataFrame(df.iloc[1:, 0].str.split( - expand=True).values, columns=amino_acids) - new_df.reset_index(drop=True, inplace=True) - new_df = new_df.apply(pd.to_numeric, errors='coerce') - return new_df + try: + # Read file line by line to handle inconsistent formats + with open(gibbs_matrix, 'r') as f: + lines = f.readlines() + + # Find amino acid header and data start + amino_acids = None + data_start_idx = None + + for i, line in enumerate(lines): + line = line.strip() + + # Look for amino acid header line + if 'A R N D C Q E G H I L K M F P S T W Y V' in line: + amino_acids = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', + 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'] + data_start_idx = i + 1 + break + elif line.startswith('A R N D C Q E G H I L K M F P S T W Y V'): + amino_acids = line.split() + data_start_idx = i + 1 + break + + # Fallback: if no header found, try original CSV approach + if amino_acids is None: + try: + df = pd.read_csv(gibbs_matrix) + amino_acids = df.iloc[0, 0].split() + data_start_idx = 1 + lines = [df.iloc[0, 0]] + [df.iloc[i, 0] for i in range(1, len(df))] + except: + # Use standard amino acid order as last resort + amino_acids = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', + 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'] + data_start_idx = 0 + + # Parse data lines + data_rows = [] + for i in range(data_start_idx, len(lines)): + line = lines[i].strip() + + # Skip empty lines and comments + if not line or line.startswith('#') or 'position-specific' in line.lower(): + continue + + # Split line into parts + parts = line.split() + + # Skip lines that don't look like data + if len(parts) < len(amino_acids): + continue + + # Extract numeric values (skip position number and amino acid identifier) + values = None + if len(parts) == len(amino_acids) + 2: # position + AA + 20 values + values = parts[2:] + elif len(parts) == len(amino_acids) + 1: # position + 20 values + values = parts[1:] + elif len(parts) == len(amino_acids): # just 20 values + values = parts + elif len(parts) > len(amino_acids) + 2: # more than expected, take middle 20 + start_idx = len(parts) - len(amino_acids) + values = parts[start_idx:] + + # Validate that we have exactly 20 numeric values + if values and len(values) == len(amino_acids): + try: + # Test if all values are numeric + [float(v) for v in values] + data_rows.append(values) + except ValueError: + continue + + if not data_rows: + raise ValueError(f"No valid data rows found in {gibbs_matrix}") + + # Create DataFrame + new_df = pd.DataFrame(data_rows, columns=amino_acids) + new_df = new_df.apply(pd.to_numeric, errors='coerce') + new_df.reset_index(drop=True, inplace=True) + + # print(f"Formatted Gibbs matrix from {gibbs_matrix}") + # print(new_df) + + return new_df + + except Exception as e: + print(f"Error processing {gibbs_matrix}: {str(e)}") + # Return empty DataFrame with standard amino acid columns as fallback + amino_acids = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', + 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'] + return pd.DataFrame(columns=amino_acids) + + @staticmethod def amino_acid_order_identical(df1: pd.DataFrame, df2: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]: @@ -376,7 
+483,49 @@ def compute_correlations( self.console.log( f"Cluster Search Preprocess completed in {elapsed_time:.2f} seconds." ) - + def read_KLD_file(self,file_path): + """ + Read & Prase Gibbs gibbs.KLDvsClusters file. + Args: + file_path (str): The path to the Gibbs gibbs.KLDvsClusters file. + Returns: + list: A list of tuples containing the cluster number and KLD value. + """ + try: + # Initialize an empty dictionary to store data + data_dict = {'cluster': []} + + with open(file_path, 'r') as file: + lines = file.readlines() + for line in lines[1:]: # Skip the header line + parts = line.strip().split('\t') + cluster_number = int(parts[0]) + kld_values = [float(value) for value in parts[1:]] + + # Ensure the dictionary has enough group columns + for i in range(1, len(kld_values) + 1): + group_key = f'group{i}' + if group_key not in data_dict: + data_dict[group_key] = [] + + # Populate the row data + data_dict['cluster'].append(cluster_number) + for i, kld_value in enumerate(kld_values, start=1): + data_dict[f'group{i}'].append(kld_value if kld_value > 0 else 0) + + # Fill remaining groups with zeros if necessary + for j in range(len(kld_values) + 1, len(data_dict) - 1): + data_dict[f'group{j}'].append(0) + + # Convert the dictionary to a DataFrame + df = pd.DataFrame(data_dict) + df.loc[:, 'total'] = df.iloc[:, 1:].sum(axis=1) + df.reset_index(drop=True, inplace=True) + # print(df) + return df + except Exception as e: + print(f"Error reading file: {e}") + sys.exit(1) def compute_correlations_v2( self, db: pd.DataFrame, @@ -397,7 +546,10 @@ def compute_correlations_v2( #Update to self self.threshold = threshold self.treshold_img = treshold_img - + self.gibbs_results = gibbs_results + self.kld_df = self.read_KLD_file( + os.path.join(self.gibbs_results, 'images', 'gibbs.KLDvsClusters.tab') + ) gibbs_result_matrix = os.path.join(gibbs_results, "matrices") should_process = False if os.path.exists(gibbs_result_matrix) and any(".mat" in file for file in os.listdir(gibbs_result_matrix)): @@ -1151,7 +1303,7 @@ def insert_script_png_json(self, script_data_path, img_fallback_path, div_id): return script_template.render(script_data_path=script_data_path, img_fallback_path=img_fallback_path, div_id=div_id) - def render_hla_section(self, hla_name, corr, best_cluster_img, naturally_presented_img): + def render_hla_section(self, hla_name, corr, best_cluster_img, naturally_presented_img,kld_clust_group_kld): if str(self.species).lower() == "human": hla_name = f"HLA-{hla_name}" @@ -1160,7 +1312,7 @@ def render_hla_section(self, hla_name, corr, best_cluster_img, naturally_present template = Template('''
-            {{ hla_name }} PCC = {{corr}}
+            Best matched allotype is {{ hla_name }} with PCC = {{corr}} and KLD = {{ kld }}
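For context on where `{{ kld }}` comes from: `read_KLD_file` above parses GibbsCluster's `gibbs.KLDvsClusters.tab`. Below is a minimal standalone sketch of that parse, assuming the layout the code relies on (a header row, a first column holding the number of clusters in the solution, and one tab-separated KLD value per group); `load_kld_table` is an illustrative name, not part of the codebase.

```python
import pandas as pd

def load_kld_table(path: str) -> pd.DataFrame:
    """Parse gibbs.KLDvsClusters.tab into one row per cluster solution."""
    rows = []
    with open(path) as fh:
        next(fh)                                            # skip the header line
        for line in fh:
            parts = line.rstrip("\n").split("\t")
            n_clusters = int(parts[0])                      # size of this cluster solution
            klds = [max(float(v), 0.0) for v in parts[1:]]  # clamp negative KLDs to 0, as above
            rows.append({"cluster": n_clusters,
                         **{f"group{i}": v for i, v in enumerate(klds, start=1)}})
    df = pd.DataFrame(rows).fillna(0.0)                     # solutions with fewer groups get 0
    df["total"] = df.filter(like="group").sum(axis=1)       # total KLD per solution
    return df
```

Unlike `read_KLD_file`, exceptions propagate to the caller instead of calling `sys.exit`, which makes the helper easier to unit-test.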
@@ -1208,7 +1360,7 @@ def render_hla_section(self, hla_name, corr, best_cluster_img, naturally_present
''') - return template.render(hla_name=hla_name, corr=corr, best_cluster_img=best_cluster_img, naturally_presented_img=naturally_presented_img) + return template.render(hla_name=hla_name, corr=corr, best_cluster_img=best_cluster_img, naturally_presented_img=naturally_presented_img,kld= kld_clust_group_kld) def make_datatable(self, correlation_dict): df = pd.DataFrame(correlation_dict.items(), @@ -1219,9 +1371,26 @@ def make_datatable(self, correlation_dict): df['HLA'] = df['HLA'].apply( lambda x: x.split('/')[-1].replace('.txt', '')) df['Correlation'] = df['Correlation'].apply(lambda x: round(x, 2)) + #Add KLD from kld_clust_group_kld + if self.kld_df is not None: + try: + df['KLD'] = df['Cluster'].apply( + lambda x: self.kld_df.loc[ + self.kld_df['cluster'] == int(str(x).split('of')[-1]), + f'group{str(x).split("of")[0]}' + ].values[0] if f'group{str(x).split("of")[0]}' in self.kld_df.columns else None + ) + except KeyError as e: + # Handle the case where the key is not found + self.console.log( + f"KeyError: The key KLD group was not found in the Corr DataFrame." + ) + df['KLD'] = 'NA' + else: + df['KLD'] = 'NA' df = df.sort_values(by='Correlation', ascending=False) df = df.reset_index(drop=True) - df = df[['Cluster', 'HLA', 'Correlation']] + df = df[['Cluster', 'HLA', 'Correlation', 'KLD']] return df def process_correlation_data(self, df=None): @@ -1302,6 +1471,8 @@ def process_correlation_data(self, df=None): def save_correlation_data(self): """Save correlation data to CSV and JSON files""" try: + if not os.path.exists(os.path.join(self._outfolder, 'corr-data')): + os.makedirs(os.path.join(self._outfolder, 'corr-data')) # Save dataframe to CSV self.corr_df.to_csv(os.path.join( self._outfolder, 'corr-data', 'corr_matrix.csv'), index=False) @@ -1540,11 +1711,37 @@ def _create_carousel_for_cluster(self, carousel_id, cluster_num, group_data): total_slides = len(group_nums) # Start building the carousel HTML + # Convert cluster_num to words for display + def number_to_words(n): + words = { + 1: "One", 2: "Two", 3: "Three", 4: "Four", 5: "Five", 6: "Six", + 7: "Seven", 8: "Eight", 9: "Nine", 10: "Ten" + } + return words.get(n, str(n)) + + cluster_num_word = number_to_words(int(cluster_num)) + # KLD = self.kld_df[] + if self.kld_df is not None: + kld_clust_df = self.kld_df[self.kld_df['cluster'] == cluster_num] + if not kld_clust_df.empty: + kld = round(kld_clust_df['total'].values[0], 2) + else: + kld = None + # Determine singular/plural for group(s) + group_label = "group" if len(group_nums) == 1 else "groups" + group_range = f"{min(group_nums)}" if len(group_nums) == 1 else f"{min(group_nums)} to {max(group_nums)}" + cluster_label = f"{cluster_num_word} cluster output" + carousel_html = f"""
-            Cluster {cluster_num}
""" @@ -1620,7 +1835,6 @@ def render_clustered_results(self, highest_corr_per_row, gibbs_out): """ Renders the complete HTML for all clustered results with carousels and necessary JavaScript. """ - html_card = """
@@ -1649,20 +1863,49 @@ def render_clustered_results(self, highest_corr_per_row, gibbs_out): # New functions ends here - def generate_html_layout(self, correlation_dict, db, gibbs_out, immunolyser=False): + def generate_html_layout(self, correlation_dict, db, gibbs_out,output_path, immunolyser=False): """ Generate an image grid for the correlation results. """ + if output_path: + if not os.path.exists(output_path): + # Create the output folder if it doesn't exist + self._outfolder = self._make_dir( + output_path, self.generate_unique_random_ids(6)[0] + ) + CONSOLE.log( + f"Output folder [bold yellow]{output_path} created successfully.", style="bold green") + else: + # Use existing folder and set _outfolder + self._outfolder = output_path + CONSOLE.log( + f"Output folder [bold yellow]{output_path} already exists. Using the existing folder.", style="bold yellow") + else: + # If no output_path provided, create a default one + self._outfolder = self._make_dir( + "data/output_default", self.generate_unique_random_ids(6)[0] + ) + CONSOLE.log( + f"No output path provided. Created default output folder: {self._outfolder}", style="bold yellow") + highest_corr_per_row = self.find_highest_correlation_for_each_row( correlation_dict ) # print(highest_corr_per_row) display_search_results(highest_corr_per_row, self.threshold) # print(highest_corr_per_row) - + + df = self.make_datatable(correlation_dict) + self.corr_df = df + self.d3js_json = self.save_correlation_data() + self.correlation_dict = correlation_dict + if self.kld_df is None: + self.kld_df = self.read_KLD_file( + os.path.join(gibbs_out, 'images', 'gibbs.KLDvsClusters.tab') + ) + # added to self self.db = db - if not os.path.exists(os.path.join(self._outfolder, 'cluster-img')): os.makedirs(os.path.join(self._outfolder, 'cluster-img')) if not os.path.exists(os.path.join(self._outfolder, 'allotypes-img')): @@ -1742,6 +1985,13 @@ def generate_html_layout(self, correlation_dict, db, gibbs_out, immunolyser=Fals $(targetId).carousel(parseInt(slideIndex)); }); }); + + document.addEventListener('DOMContentLoaded', function () { + var popoverTriggerList = [].slice.call(document.querySelectorAll('[data-bs-toggle="popover"]')); + popoverTriggerList.forEach(function (popoverTriggerEl) { + new bootstrap.Popover(popoverTriggerEl); + }); + }); """ body_end_1 = Template("""
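The KLD column that `make_datatable` adds above hinges on Gibbs cluster labels of the form `XofY`, i.e. group X of the Y-cluster solution. Here is a small sketch of that lookup against the table produced by `read_KLD_file`; `kld_for_label` is an illustrative name, not part of the codebase.

```python
from typing import Optional

import pandas as pd

def kld_for_label(label: str, kld_df: pd.DataFrame) -> Optional[float]:
    """Return the KLD for a label like '3of5' (group 3 of the 5-cluster solution)."""
    group, n_clusters = str(label).split("of")
    col = f"group{group}"
    if col not in kld_df.columns:
        return None                                  # solution has fewer groups than requested
    row = kld_df.loc[kld_df["cluster"] == int(n_clusters), col]
    return float(row.values[0]) if not row.empty else None

# e.g. kld_for_label("3of5", kld_df) -> KLD of group 3 in the 5-cluster run, or None
```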
@@ -1756,17 +2006,19 @@ def generate_html_layout(self, correlation_dict, db, gibbs_out, immunolyser=Fals - + @@ -2599,8 +2852,8 @@ def run_cluster_search(args): # if args.output_folder is None: # os.makedirs("output", exist_ok=True) # output_folder = "output" - cluster_search = ClusterSearch() - cluster_search.console.rule( + cluster_search_html = ClusterSearch() + cluster_search_html.console.rule( "[bold red]Stage 1/2: Data processing for correlation matrices." ) # print(args.species) @@ -2609,7 +2862,8 @@ def run_cluster_search(args): f"Species provided: [bold yellow]{args.species}", style="bold green") CONSOLE.log( f"Loading reference databse for [bold yellow]{args.species}", style="bold green") - db = cluster_search._db_loader(args.reference_folder, args.species) + db = cluster_search_html._db_loader(args.reference_folder, args.species) + CONSOLE.log(f"Reference database loaded successfully.", style="bold green") else: @@ -2637,16 +2891,19 @@ def run_cluster_search(args): spinner="squish", spinner_style="yellow", ) + DB_HLA_list = [hla.lower() for hla in db['formatted_allotypes'].values] for u_hla in str(args.hla_types).split(','): - if u_hla in db['formatted_allotypes'].values: + u_hla = u_hla.replace("[","").replace("]","").replace("'","") + if u_hla.lower() in DB_HLA_list: status.update( status=f"[bold blue] HLA/MHC {u_hla} found in databse", spinner="squish", spinner_style="yellow", ) u_hla_list.append(u_hla) - CONSOLE.log( - f"HLA/MHC [bold yellow]{u_hla}[yellow] found in databse", style="bold green") + # time.sleep(10) + # CONSOLE.log( + # f"HLA/MHC [bold yellow]{u_hla}[yellow] found in databse", style="bold green") else: status.update( status=f"[bold blue] HLA/MHC {u_hla} not found in databse", @@ -2663,21 +2920,39 @@ def run_cluster_search(args): args.hla_types = None CONSOLE.log( f"No HLA/MHC allotypes types provided. Using all available HLA types from the reference folder.", style="bold yellow") - + CONSOLE.log(f"calculating compute_correlations.", style="bold blue") - cluster_search.compute_correlations_v2( - db, - args.gibbs_folder, - args.n_clusters, - args.output, - args.hla_types, - args.threshold, - args.treshold_img - ) + + if args.Searchtype == "Numba": + CONSOLE.log(f"Using Numba for correlation calculation.", style="bold blue") + cluster_search = NP_clusterSearchCLI(cache_dir = os.path.join(args.reference_folder,"numba_cache")) + cluster_search.compute_correlations_V3( + db, + args.gibbs_folder, + args.n_clusters, + args.output, + args.hla_types, + args.threshold, + args.NumbaDB, + args.species, + args.topNHits + ) + # breakpoint() + else: + cluster_search_html.compute_correlations_v2( + db, + args.gibbs_folder, + args.n_clusters, + args.output, + args.hla_types, + args.threshold, + args.treshold_img + ) # cluster_search.generate_image_grid(cluster_search.correlation_dict,db) - cluster_search.generate_html_layout( - cluster_search.correlation_dict, db, args.gibbs_folder, args.immunolyser) + cluster_search_html.generate_html_layout( + cluster_search.correlation_dict, db, args.gibbs_folder, args.output, args.immunolyser + ) # breakpoint() @@ -2689,11 +2964,11 @@ def run_cluster_search(args): # args.hla_types, # ) - cluster_search.console.rule( + cluster_search_html.console.rule( "[bold red]Stage 2/2: Finding best matching Naturally presented HLA ." 
) # if args.output_folder is None: - cluster_search.plot_heatmap(args.output) + # cluster_search_html.plot_heatmap(args.output) # cluster_search.console.rule("[bold red]Stage 3/4: Cheking the HLA.") @@ -2722,22 +2997,22 @@ def run_cluster_search(args): # if __name__ == "__main__": - # # ClusterSearch().compute_correlations( - # # "data/9mersonly", - # # "data/ref_data/Gibbs_motifs_mouse/output_matrices", - # # "all", - # # "data/outputM", - # # ) - - # # ClusterSearch()._compute_and_log_correlation( - # # "data/9mersonly", - # # "data/ref_data/Gibbs_motifs_human/output_matrices_human", - # # "cluster_1of5.mat", - # # "HLA_A_02_01.txt", - # # ) - # # print(sys.argv) - - # # print(ClusterSearch()._db_loader("data/ref_data/","mouse")) + # ClusterSearch().compute_correlations( + # "data/9mersonly", + # "data/ref_data/Gibbs_motifs_mouse/output_matrices", + # "all", + # "data/outputM", + # ) + + # ClusterSearch()._compute_and_log_correlation( + # "data/9mersonly", + # "data/ref_data/Gibbs_motifs_human/output_matrices_human", + # "cluster_1of5.mat", + # "HLA_A_02_01.txt", + # ) + # print(sys.argv) + + # print(ClusterSearch()._db_loader("data/ref_data/","mouse")) # run_cluster_search( # argparse.Namespace( @@ -2753,3 +3028,23 @@ def run_cluster_search(args): # immunolyser=False # ) # ) + + + # run_cluster_search( + # argparse.Namespace( + # credits=False, + # gibbs_folder="data/P6215/mhcI_1224927", + # species="human", + # Searchtype="Numba", + # hla_types=None, + # reference_folder="data/ref_data/", + # data_dir="data/ref_data/human_db", + # threshold=0.7, + # log=False, + # n_clusters="all", + # output="data/outputHumanTest", + # processes=4, + # version=False, + # immunolyser=False + # ) + # ) diff --git a/cli/database_gen.py b/cli/database_gen.py index 2da76d9..aa2e83f 100644 --- a/cli/database_gen.py +++ b/cli/database_gen.py @@ -131,13 +131,160 @@ def Database_gen(config_file): ) sys.exit(0) -# Database_gen("config.json") + +def generate_HLA_freq_database(out_put_ditectory=None): + """Generate HLA frequency database.""" + if out_put_ditectory is None: + CONSOLE.log("Output directory not specified. 
Please provide a valid output directory.", style="red") + sys.exit(1) + try: + from .HLAfreq import get_list, makeURL, getAFdata, url_encode_name + except ImportError: + CONSOLE.log("HLAfreq function Not avilable please check....", style="red") + sys.exit(1) + + # from .HLAfreq import HLAfreq_pymc as HLAhdi + import pandas as pd + import numpy as np + import matplotlib.pyplot as plt + + + #country_list = ["Australia", "Thailand", "United States of America", "United Kingdom", "Germany", "France", "Italy", "Spain", "Netherlands", "Sweden", "Norway", "Finland", "Denmark", "Belgium", "Switzerland", "Austria", "Poland", "Czech Republic", "Hungary", "Portugal"] + # country_list = ["Australia"] + CONSOLE.log("Processing HLA frequency data for countries...", style="blue") + country_list = get_list("country") + # print(country_list) + # for country in country_list: + # print(f"Processing {url_encode_name(country)}...") + CONSOLE.log(f"Total countries to process: {len(country_list)}", style="yellow") + for country in country_list: + CONSOLE.log(f"Processing {url_encode_name(country)}...", style="blue") + base_url = makeURL(country) + try: + CONSOLE.log(f"Base URL for {country}: {base_url}", style="green") + aftab = getAFdata(base_url) + aftab.to_csv(f"{out_put_ditectory}/{country}_raw.csv", index=False) + CONSOLE.log(f"Raw data for {country} saved to example/{country}_raw.csv", style="green") + except Exception as e: + CONSOLE.log(f"Error processing {country}: {e}", style="red") + continue + + +def add_allefreq_to_db(db_pat, allefreq_path,allefreq_output,loci="A",freq=0.01): + try: + from .HLAfreq import only_complete, combineAF, decrease_resolution,check_resolution + import matplotlib.pyplot as plt + import pandas as pd + except ImportError: + CONSOLE.log("HLAfreq function Not available please check....", style="red") + sys.exit(1) + """Add allele frequency data to the database.""" + if not os.path.exists(db_pat): + CONSOLE.log(f"Database file {db_pat} does not exist.", style="red") + sys.exit(1) + + if not os.path.exists(allefreq_path): + CONSOLE.log(f"Allele frequency file {allefreq_path} does not exist.", style="red") + sys.exit(1) + + # Load the database and allele frequency data + db = pd.read_csv(db_pat) + # print(db) + # allefreq = pd.read_csv(allefreq_path) + country_files = [f for f in os.listdir(allefreq_path) if f.endswith('_raw.csv')] + if not country_files: + CONSOLE.log(f"No allele frequency files found in {allefreq_path}.", style="red") + sys.exit(1) + country_count = 0 + cafs = [] + for country in country_files: + country_count += 1 + country_path = os.path.join(allefreq_path, country) + allefreq = pd.read_csv(country_path) + CONSOLE.log(f"Processing allele frequency data for {country}... {country_count}/{len(country_files)}", style="blue") + # print(allefreq.allele) + # print(allefreq) + # Drop any incomplete studies + aftab = only_complete(allefreq) + check = check_resolution(aftab) + if check: + aftab = decrease_resolution(aftab, 2) + afloc = aftab[aftab.loci==loci] + if afloc.empty: + CONSOLE.log(f"No A locus data found for {country}. 
Skipping...", style="yellow") + continue + caf = combineAF(afloc) + caf['country'] = country.split('_')[0] # Extract country name from filename + cafs.append(caf) + else: + continue + # Ensure all alleles have the same resolution + # aftab = decrease_resolution(aftab, 2) + cafs = pd.concat(cafs, ignore_index=True) + cafs.to_csv(f"{allefreq_output}/01HLA_freq_by_country_all_HLA-{loci}.csv", index=False) + international = combineAF(cafs, datasetID='country') + mask = international.allele_freq > freq + international[mask].plot.barh('allele', 'allele_freq') + plt.savefig(f"{allefreq_output}/01HLA_freq_by_country_Top20_HLA-{loci}_freq>{freq}.png") + # plt.show() +# # # Database_gen("config.json") +import os + +import os +import pandas as pd + +def _formate_mouse_DB(ref_db_path, db_path, out_db_path=None): + # Load DB + db = pd.read_csv(db_path) + + # Walk through motif + matrices folders and rename files + for folder in os.listdir(ref_db_path): + if folder in ["motif", "matrices"]: + folder_path = os.path.join(ref_db_path, folder) + for fname in os.listdir(folder_path): + if fname.startswith("MHC_"): + old_path = os.path.join(folder_path, fname) + new_name = fname.replace("MHC_", "", 1) # replace prefix + new_path = os.path.join(folder_path, new_name) + os.rename(old_path, new_path) + print(f"Renamed: {old_path} → {new_path}") + + # Update DB entries that reference this filename + db = db.replace(fname, new_name) + + # Save updated DB + if out_db_path is None: + out_db_path = db_path # overwrite input DB + db.to_csv(out_db_path, index=False) + print(f"Database updated and saved to {out_db_path}") + # if __name__ == "__main__": -# config_file = "config.json" -# config = _prase_config_file(config_file) -# _check_ref_files(config) -# hla_list = _HLA_liist(config) -# # print(hla_list) -# # print(config) -# CONSOLE.log("Config file parsed successfully") -# sys.exit(0) \ No newline at end of file + # config_file = "config.json" + # config = _prase_config_file(config_file) + # _check_ref_files(config) + # path="/home/sson0030/xy86_scratch2/SANJAY/MHC-TP/data/ref_data/Gibbs_motifs_mouse" + # path="data/ref_data/Gibbs_motifs_mouse" + # db_path="data/ref_data/mouse.db" + # _formate_mouse_DB(path,db_path) + # hla_list = _HLA_liist(config) + # print(hla_list) + # print(config) + # CONSOLE.log("Config file parsed successfully") +# sys.exit(0) + +# ## HLA frequency database generation + + +# # Combine studies within country +# caf = combineAF(aftab) +# # Add country name to dataset, this is used as `datasetID` going forward +# caf['country'] = country +# cafs.append(caf) + +# cafs = pd.concat(cafs, ignore_index=True) +# international = combineAF(cafs, datasetID='country') +# print(international) +# db_pat = "/home/sson0030/xy86_scratch2/SANJAY/MHC-TP/data/ref_data/human.db" +# allefreq_path = "/home/sson0030/xy86_scratch2/SANJAY/MHC-TP/data/ref_data/HLAfreq/byCountry/" +# allefreq_output = "/home/sson0030/xy86_scratch2/SANJAY/MHC-TP/data/ref_data/HLAfreq/output" +# add_allefreq_to_db(db_pat, allefreq_path,allefreq_output,loci="C",freq=0.01) \ No newline at end of file diff --git a/cli/html_config.py b/cli/html_config.py index 9b26a1c..9711f67 100644 --- a/cli/html_config.py +++ b/cli/html_config.py @@ -9,7 +9,7 @@ - Clust-search {{ version }} + MHC-TP {{ version }} @@ -192,6 +192,7 @@
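A side note on the allotype validation in `run_cluster_search` above: user-supplied `--hla_types` are matched case-insensitively against the reference database after stripping stray list punctuation. A standalone sketch of that contract follows; `validate_hla_types` is an illustrative name, and the real code logs via rich and keeps only the matches.

```python
def validate_hla_types(requested: str, db_allotypes: list[str]) -> tuple[list[str], list[str]]:
    """Split a comma-separated HLA string and match it case-insensitively against the DB."""
    known = {a.lower() for a in db_allotypes}
    found, missing = [], []
    for token in str(requested).split(","):
        token = token.strip().strip("[]'\" ")        # tolerate list-style punctuation, as in the CLI
        (found if token.lower() in known else missing).append(token)
    return found, missing

# e.g. validate_hla_types("A0201, b1302, X9999", ["A0201", "B1302"])
#      -> (["A0201", "b1302"], ["X9999"])
```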
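Taken together, the `database_gen.py` additions build the allele-frequency layer in two passes: `generate_HLA_freq_database` downloads one `*_raw.csv` per country from allelefrequencies.net, and `add_allefreq_to_db` keeps complete studies, harmonises allele resolution, combines studies within each country, and then combines countries into an international estimate. Below is a compressed sketch of the second pass, reusing the HLAfreq helpers already imported in the diff; the `cli.HLAfreq` import path, the function name, and the directory layout are assumptions.

```python
import os

import pandas as pd
from cli.HLAfreq import (check_resolution, combineAF, decrease_resolution,
                         only_complete)

def build_locus_frequencies(raw_dir: str, loci: str = "A") -> pd.DataFrame:
    """Combine per-country *_raw.csv downloads into international frequencies for one locus."""
    per_country = []
    for fname in sorted(os.listdir(raw_dir)):
        if not fname.endswith("_raw.csv"):
            continue
        studies = only_complete(pd.read_csv(os.path.join(raw_dir, fname)))  # drop incomplete studies
        if not check_resolution(studies):
            continue                                 # mirror the diff: only proceed when the check passes
        studies = decrease_resolution(studies, 2)    # reduce alleles to two-field resolution
        locus = studies[studies.loci == loci]
        if locus.empty:
            continue
        caf = combineAF(locus)                       # combine studies within the country
        caf["country"] = fname.split("_")[0]         # country name from the filename
        per_country.append(caf)
    if not per_country:
        raise ValueError(f"No usable {loci}-locus data found under {raw_dir}")
    combined = pd.concat(per_country, ignore_index=True)
    return combineAF(combined, datasetID="country")  # combine across countries
```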