Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
224 changes: 224 additions & 0 deletions chai_lab/data/dataset/msas/colabfold_codes/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from __future__ import annotations

import warnings
from Bio import BiopythonDeprecationWarning # what can possibly go wrong...
warnings.simplefilter(action='ignore', category=BiopythonDeprecationWarning)

import logging
import random
from pathlib import Path
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
import pandas

logger = logging.getLogger(__name__)

def parse_fasta(fasta_string: str) -> Tuple[List[str], List[str]]:
    """Parses FASTA string and returns list of strings with amino-acid sequences.

    Arguments:
        fasta_string: The string contents of a FASTA file.

    Returns:
        A tuple of two lists:
        * A list of sequences.
        * A list of sequence descriptions taken from the comment lines. In the
          same order as the sequences.
    """
    sequences: List[str] = []
    descriptions: List[str] = []
    current = -1
    for raw_line in fasta_string.splitlines():
        stripped = raw_line.strip()
        if stripped.startswith("#"):
            # '#' lines are treated as comments and skipped entirely.
            continue
        if stripped.startswith(">"):
            # Header line: start a new record (drop the leading '>').
            current += 1
            descriptions.append(stripped[1:])
            sequences.append("")
        elif stripped:
            # Sequence line: accumulate onto the current record.
            sequences[current] += stripped

    return sequences, descriptions

def get_queries(
    input_path: Union[str, Path], sort_queries_by: str = "length"
) -> Tuple[List[Tuple[str, str, Optional[List[str]]]], bool]:
    """Reads a directory of fasta files, a single fasta/a3m file or a csv/tsv
    file and returns the parsed queries.

    Arguments:
        input_path: File (.csv/.tsv/.a3m/.fasta/.faa/.fa) or directory of
            fasta/a3m files to read queries from.
        sort_queries_by: "length" (default) sorts by total sequence length,
            "random" shuffles; any other value keeps file order.

    Returns:
        A tuple of (queries, is_complex). Each query is a tuple of
        (job name, sequence or list of per-chain sequences, optional a3m
        lines). ``is_complex`` is True if any query describes a multi-chain
        complex.

    Raises:
        OSError: If ``input_path`` does not exist.
        ValueError: For an unsupported file extension or an empty a3m file.
    """
    input_path = Path(input_path)
    if not input_path.exists():
        raise OSError(f"{input_path} could not be found")

    if input_path.is_file():
        if input_path.suffix == ".csv" or input_path.suffix == ".tsv":
            sep = "\t" if input_path.suffix == ".tsv" else ","
            df = pandas.read_csv(input_path, sep=sep, dtype=str)
            assert "id" in df.columns and "sequence" in df.columns
            queries = [
                (seq_id, sequence.upper().split(":"), None)
                for seq_id, sequence in df[["id", "sequence"]].itertuples(index=False)
            ]
            # Collapse single-chain entries from a one-element list to a plain
            # string, matching the fasta branches below.
            for i in range(len(queries)):
                if len(queries[i][1]) == 1:
                    queries[i] = (queries[i][0], queries[i][1][0], None)
        elif input_path.suffix == ".a3m":
            (seqs, _headers) = parse_fasta(input_path.read_text())
            if len(seqs) == 0:
                raise ValueError(f"{input_path} is empty")
            query_sequence = seqs[0]
            # Use a list so we can easily extend this to multiple msas later
            a3m_lines = [input_path.read_text()]
            queries = [(input_path.stem, query_sequence, a3m_lines)]
        elif input_path.suffix in [".fasta", ".faa", ".fa"]:
            (sequences, headers) = parse_fasta(input_path.read_text())
            queries = []
            for sequence, header in zip(sequences, headers):
                sequence = sequence.upper()
                if sequence.count(":") == 0:
                    # Single sequence
                    queries.append((header, sequence, None))
                else:
                    # Complex mode
                    queries.append((header, sequence.upper().split(":"), None))
        else:
            raise ValueError(f"Unknown file format {input_path.suffix}")
    else:
        assert input_path.is_dir(), "Expected either an input file or a input directory"
        queries = []
        for file in sorted(input_path.iterdir()):
            if not file.is_file():
                continue
            # BUGFIX: ".fa" was missing from this filter even though it is
            # accepted in single-file mode and explicitly handled below, so
            # .fa files in a directory were silently skipped with a warning.
            if file.suffix.lower() not in [".a3m", ".fasta", ".faa", ".fa"]:
                logger.warning(f"non-fasta/a3m file in input directory: {file}")
                continue
            (seqs, _headers) = parse_fasta(file.read_text())
            if len(seqs) == 0:
                logger.error(f"{file} is empty")
                continue
            query_sequence = seqs[0]
            # Use .lower() consistently with the filter above.
            if len(seqs) > 1 and file.suffix.lower() in [".fasta", ".faa", ".fa"]:
                logger.warning(
                    f"More than one sequence in {file}, ignoring all but the first sequence"
                )

            if file.suffix.lower() == ".a3m":
                a3m_lines = [file.read_text()]
                queries.append((file.stem, query_sequence.upper(), a3m_lines))
            else:
                if query_sequence.count(":") == 0:
                    # Single sequence
                    queries.append((file.stem, query_sequence, None))
                else:
                    # Complex mode
                    queries.append((file.stem, query_sequence.upper().split(":"), None))

    # sort by seq. len
    if sort_queries_by == "length":
        queries.sort(key=lambda t: len("".join(t[1])))

    elif sort_queries_by == "random":
        random.shuffle(queries)

    is_complex = False
    for _, query_sequence, a3m_lines in queries:
        if isinstance(query_sequence, list):
            is_complex = True
            break
        if a3m_lines is not None and a3m_lines[0].startswith("#"):
            # a3m header: "#<lens,comma-sep>\t<cardinalities,comma-sep>"
            a3m_line = a3m_lines[0].splitlines()[0]
            tab_sep_entries = a3m_line[1:].split("\t")
            if len(tab_sep_entries) == 2:
                query_seq_len = tab_sep_entries[0].split(",")
                query_seq_len = list(map(int, query_seq_len))
                query_seqs_cardinality = tab_sep_entries[1].split(",")
                query_seqs_cardinality = list(map(int, query_seqs_cardinality))
                # A single protein has exactly one chain with cardinality 1;
                # anything else is a complex.
                is_single_protein = (
                    len(query_seq_len) == 1 and query_seqs_cardinality[0] == 1
                )
                if not is_single_protein:
                    is_complex = True
                    break
    return queries, is_complex

def pair_sequences(
    a3m_lines: List[str], query_sequences: List[str], query_cardinality: List[int]
) -> str:
    """Horizontally concatenate the per-chain a3m alignments row by row.

    All alignments are assumed to have the same number of rows (one row per
    header or sequence line). Headers of chains after the first are joined
    with a tab instead of '>'; sequence rows are repeated once per copy of
    the chain (its cardinality).
    """
    num_rows = len(a3m_lines[0].splitlines())
    paired_rows = [""] * num_rows
    for chain_idx in range(len(query_sequences)):
        for row_idx, row in enumerate(a3m_lines[chain_idx].splitlines()):
            if row.startswith(">"):
                if chain_idx != 0:
                    # Merge into the first chain's header via a tab separator.
                    row = row.replace(">", "\t", 1)
                paired_rows[row_idx] += row
            else:
                paired_rows[row_idx] += row * query_cardinality[chain_idx]
    return "\n".join(paired_rows)

def pad_sequences(
    a3m_lines: List[str], query_sequences: List[str], query_cardinality: List[int]
) -> str:
    """Pad each chain's alignment rows with all-gap columns for every other chain.

    Produces an unpaired a3m where each sequence row occupies its own chain's
    column block and '-' gaps fill the blocks of all other chain copies.
    """
    # One all-gap placeholder per chain copy, each as long as its query.
    gap_blocks: List[str] = []
    for n, seq in enumerate(query_sequences):
        gap_blocks.extend(["-" * len(seq)] * query_cardinality[n])

    out_lines: List[str] = []
    block_pos = 0  # index of the column block the current chain copy occupies
    for n, _seq in enumerate(query_sequences):
        for _copy in range(query_cardinality[n]):
            for row in a3m_lines[n].split("\n"):
                if not row:
                    continue
                if row.startswith(">"):
                    out_lines.append(row)
                else:
                    padded = gap_blocks[:block_pos] + [row] + gap_blocks[block_pos + 1 :]
                    out_lines.append("".join(padded))
            block_pos += 1
    return "\n".join(out_lines)

def pair_msa(
    query_seqs_unique: List[str],
    query_seqs_cardinality: List[int],
    paired_msa: Optional[List[str]],
    unpaired_msa: Optional[List[str]],
) -> str:
    """Combine paired and/or unpaired per-chain MSAs into one a3m string.

    The paired MSA (if any) is concatenated horizontally via
    ``pair_sequences``; the unpaired MSA (if any) is gap-padded into the
    full-complex width via ``pad_sequences``. When both are given the paired
    block comes first.

    Raises:
        ValueError: If both ``paired_msa`` and ``unpaired_msa`` are None.
    """
    if paired_msa is None and unpaired_msa is None:
        # FIX: was an f-string with no placeholders (Ruff F541).
        raise ValueError("Invalid pairing")
    parts: List[str] = []
    if paired_msa is not None:
        parts.append(
            pair_sequences(paired_msa, query_seqs_unique, query_seqs_cardinality)
        )
    if unpaired_msa is not None:
        parts.append(
            pad_sequences(unpaired_msa, query_seqs_unique, query_seqs_cardinality)
        )
    return "\n".join(parts)

def msa_to_str(
    unpaired_msa: List[str],
    paired_msa: List[str],
    query_seqs_unique: List[str],
    query_seqs_cardinality: List[int],
) -> str:
    """Serialize the MSAs into a single a3m string.

    The first line is a '#' header of comma-separated query lengths, a tab,
    and comma-separated cardinalities; the combined alignment follows.
    """
    lengths = ",".join(str(len(seq)) for seq in query_seqs_unique)
    cardinalities = ",".join(str(c) for c in query_seqs_cardinality)
    header = f"#{lengths}\t{cardinalities}\n"
    # build msa with cardinality of 1, it makes it easier to parse and manipulate
    unit_cardinality = [1] * len(query_seqs_cardinality)
    return header + pair_msa(
        query_seqs_unique, unit_cardinality, paired_msa, unpaired_msa
    )

2 changes: 2 additions & 0 deletions chai_lab/data/dataset/msas/colabfold_codes/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def safe_filename(file: str) -> str:
    """Replace every character that is not alphanumeric, '_', '.' or '-' with '_'."""
    allowed_extras = {"_", ".", "-"}
    sanitized = []
    for ch in file:
        sanitized.append(ch if ch.isalnum() or ch in allowed_extras else "_")
    return "".join(sanitized)
Loading