
ProtFlash einops error with variable-length input during nn.DataParallel inference. #5

@ayaanhossain

Description


Describe the bug
ProtFlash, when wrapped with torch.nn.DataParallel, crashes with an einops error.

To Reproduce

import torch
from ProtFlash.pretrain import load_prot_flash_base
from ProtFlash.utils import batchConverter

# Proteins we want to embed
data = [
    ("protein1", "FVLHFPCIHDHMQAVVQWITRMDMVYKFRADMYGKGKNNGFRDVHCVDHQQRFGWTQTMPDYPSGWEVASFKKSHIGRRASAPIGLGSYTKLKSMKIHMQKSTKWDWGMFKHQATVMMQEREQGRSENLGNYYTMNHCNTERRRHIITVIYMINPYRMRLKTNQKFLYNQCYFYKWVRWNTDAMTSMLNVTCNHSLYKQCWDHTYLLAYKDPQGSNEQNTDEGHVRMMVKECGPKILYYDCFTKVPMFHDLFGSWLMWLPILLQNLAVVDGYGTSVLMTEGDSYCEGVKFGNICTIFRRDASPSPINRIWVSYICLRSIGKSAGESKAFKYMVRVGWFKQGFCLYEKWSLDFFNEFEHPIGLNVWNNQKHHDTGFYLPFFRKDTIQHMSVEQWPDDECRYLNSIKNGAMTAEFSLMPFQCTPDKAQRFHEIFFKVSGWEWGMESGVALDVIQVMAEEWWLVFQIFFIHEHCYHCNVHTNRGHWSSVHGWNGFVAARDYIRLADRANHHWNNHIQENPASMERYIFGNLKRVAFINYQMAADPGFQTKVRRVDRYYNRVTVTIRISTWNKDPMLQTDKRTSSTYNFMQCRMWEWKNPNVDRYEFKRYSSKPEKDTSLACREFVVSGNEVARLRKDRPKHHLFFFWEDDYLGIAGFSAISTLEHEPGPPWHYIQPTKSHLLQNASRCYPALTFWIEHMQWEYMCWHPQEQGTDCMCPMPLSPIFCFEAGLDFADKPSPNWFSWMTCVMEARKGIYIAFDQSTPSIPCMMMRHMGFCGSWPPNKIPSIMKFGAKQQYRSKFHQVPLPANYVLHRPQLEAFMILFWHTIKNFEHSDKQQKAKTVQHEWMGFMFDPQKCHTEDCTCNYPIEDYDTIHRLWKQVKFYENYATTWIFLCTPLLQLTWWPRGIVCMEKDRTRQEMCHLQCTKVMMQTDLKTVFIMIDGILWMIPSKLEVVDHFWAEGMFMRTCCAWPMSNELQNLAFKHGPILSAHAWDTEGNTDEVGERAVVTGCIHYPWRRCANEYCWQCAIHGMCCQYFWHHAILPDDKADPYNDYLRENAPEWAPCQLYPHANKANEEHEAKECKAYKYVYMIRATGEFQATKRHCDPRFHWTIFGMMDALSDIDFNGFYDDWMIMFMPFTKYAGRVEYEIGIVRSFWKPQPNLPWILDHSPEMTCAIPGTTMHTVKLRQCQWRLLFPFHFHHARKWDQDMLKPTGTDGGYCIISLYRQRDKNECCSPFRQTWVDSIIPNIYDCTIKPKNADWKDHHFGRCGKMVWGMNDAQDLCRSNTKDDEEGENDPHQLWKGNPKMKG"),
    ("protein2", "SDAMFQANYAGTMDESVHKSDESDSLYWERMTCAMVISWVLQMQMMKPRWNCNGANWLKAQATTYFDGSEVDAERNHPPMLCFMCKCSCAPEHGERQFQQMNWKIMTQAKREYSTANMPTHMQRSHYLNAKNNCFFPGNRSDVDEVEQIKKMLFSNFKCWRGQQNAVYYYWRAKFETTLRCEVKMKETRQHPYPLFLHNDSHKMKRFREHNVENAASRHPHFKGTIGMRFWKPYKERKYIQMEGLCIITCTIPMGWERPVKKFWCVECFDIGMMCDEAGMGGKMGEILWICQCEDVYMNPNMYRAIYIEVPMICIKWGDIYLHQPGYNAIWIMNQVDLMDQCKGDTFFTVFQSYMVDCQNADNFQINKHHQLNNEWMRVWQSKRGQAWFPDYTECTYDPFRICGMSYPVWDASDDKWSLEHQDMGFQWDSPYWCRNCFWMHGTFDNYDFKEPFCFWWLITEPQQERNVDNTFRCRNEARFLAYMWDENNSWTCGWMVMAKTPCETYRCPTPNKFFRAAISPNMCGARRFGGVLSKQMFVQSWFANGKEIIHRNKCMDGICIFTVNPHKSHAHFIEDYCQRNWGDNPVHEWQSKKIREYDKGPQWSFEVLCTPSNNVLKNEDTELFENSKCFDVGHAKWIPDDRAKGMYATMRCVWPHNANSFFEECFPGSTRMGWEKKRMSWIVDWDTIYCKVFAGINFDDWYEEAYQRGPAQNGQPKDWFGQFHQNPKKPDDSQTFYFPDWGSGSKADQEFMFCPQRWCIEVMCQFRLRMTMAASHGNDKYAAFTIPRAQDVGDHDNGHSCIDLGATASWKGSYCYWCARANAQECADMASNGYPQDQMYLQHCVRHQLPWRFTLTDAYAWRDVDDNGGENFMLQEPRIFMNAMDMWMRGWTWRPMWTIAEAECYCPHDQYSYNQDTRFHNTFAGTFMTRGQDCSSYQKEGPHLDKECCLHCLDCKYFFFHRGPVWMQWSGCCVMGHFSNVGKCYLMIFWLQFSACTGFMEADQGALTHAELDYSHWGCWYGNDPQQMQMTDLIMCGYCPPKFCADETPPCIACKKHPRRNMLELYHRYLVLCCWKGKFVYARHWAEKNGACWSWMMSNKGLAERVCCASGGYKSQCSMNETWIATVLNEAWLFQAFKHDHHMMVSGVLDHRTHDECQLKQTSCNVSGIVMCVGELGHMMWLMQDQVGDHFNVNCGQGNQGPICHKVMVTDHLRINPWMKGYGCVLWPMLEMQYMWWLKCFVCIQFYPFRVWTARCAQLPTAMTKLLWAVTTILHFVLRCSPYCMLMAKGNEKPASRAAT"),
    ("protein3", "AKCTLGLDYGKCEARWYFSLAMIWRLYYYFVLRVDKCFITVIGMWYEDQGEIIMERQKDLQHFLHIHGRKKPHCTANFEVKHRMLMCYFQRHCGSWWENDQIQQYLHMDCLTMSDKNHAWNFFWCRFHFTHFFEWHIIYHHGGANEEGRNYHWLSMWGSASRTLVDKCSQSSGAWLAWYKSAFPCQSRSSNLTHCRFYITPKKFYITIAVDFVVWIRGRLAIKFHLADSHNMMILTYFQYSLYNTDCWMMDETGNGFDWWCHMIRFAHNDTALLGFAPHCIWFVFDFGCNHRLKKDQCRYSKGVFLSWWAVCNWPWHGGHQKRIGDMVVFLQNANCPCHNPKASWVRIVLCTGWHVMKGTAKHMFPIDEQFGPGIFHQMYSHGNFYGLCWWSHAQKYQMYSKAKMCQRLARTVNPWNKRNTMVCMQEDVRPLIDVHQQGEQALVQSAEFGEYDQNNEGNGQARNVYRYWGREKFLVKAGSLMKGNPVNYTAIIDSHDGFDCSTLNWTYAIRAGFIYGECECNILTNDHGTGCRICQEVTHMMPAGDQALRWGRAYVQTAAGVAMTREKMSIFSVLLVYIDNTWAGLCVCPWQSFTIAHHLFKIGQGSFVHDSVENNCKQYTCEKDMCGSDYFHRCATNHRTHMGNEYYFIDLCIMNNESQRKVISEFMMRGVGMWMQWYLIEMWLHCIEYMCATWSCCERSTCDVWRCAAQTPFRATVRVNKGWQEYANEPIKQRHYLDEQMHLTNAALITNRNHPPRDPFSSPWMNSCCFTMYVMQGAINDERVTHNYNGHRKTVRHLFHGASDHEDHWWILYEWCEHTTITSSSNCWNVYDIEAYWWVPLYMPERTQLEPQSWLTRFTANWPWAMSPVKQIVCFCQEGTHDEWMYEHDIMSHSPGWKDYWAVWTFPMPNPLQMYWNDSDHGSLGLKLHTFCMVHYKDNGWGMMFLCPRWGFMQYFFFKVTNNQCRFLQNGILRCKPAYHPMPRKTHMDPFQSVWCCWHGTHAREHNKVHEDEKAITSPVSPVCTGRWLAHIWWMWKITQLKQRLCDKNSESPAHGVWTMLGSAFCLGEHVMTWVHVWYHENWVMDIQMHHAGQFYANLTVMPQLDKFKNEFTHEQRTRYNGVFGTVHRWSPDTAYDIETIRAKWWEIIMMCSEGTQIWIMFDLMSIVNARNIKGNMKFILKGNELCKSTQRAHPSNTFDANPFSRHFRQDFLLDWIIAVEYLVDDSTMWIWVTQQMIRTMQKLVGEHDMPSYAFGFTCVYEGLMYIIMLLWSFHQIIDFRTSGCDILVVQGEMHMFNK"),
    ("protein4", "NDLHDCLWEIESIYFANPYVVPHKCKHEMHYLYPEQMGLVIGWRTWCFATWDRMAWVMQISKYADYGGSDLGEYEAQVCQMTCYPTHWPCWSGMIMYVEYKQSLKKVIFHEICARAMSRKEWTNAGDVEFYTEILPWVYEMAHDWADEDMHWFMPPSELSVNPVPWCHCAKEIKPWTHGSCMIDNGDPDQDKSESRVDRWSDTCNLMLLWKFYCLLLWIYARNFERPNANYKLVFKRASTRPRSVIKPETSPGHKPSQWYQHNIPTHLRLKMENRHCIEHPCREVRFMCCYSDFEIGGMDVRRDQTDRHSLSDIFFGICQMTQDAILFSNRCTEKKTFFHHDDEKHIRFWRDGDFHCKQGVHVEMSAWHLPPVPYKNKVTWLSHKRAQNLKLLIKMSMPTYCQEESATSYQVCVDGCIRFTNPWKHYMNGEDEERAPWSIATQQERDETYFWWSCGQYTIEYFHYMFKYKQFSSNSNHRTTEQNMHYVEEVLRTSACIHRDHEDDEMIKLRVMPRQCYDIPIFWAFYRARYCIKCEDDIKEGESEWMSPFNFCFFCWLIWEITPIFADSPKSDEIWNTFTAELAVVMGCNMRNCCSDWCRYKAVDMYRTDRPGICYFGLQLLNVITSYWSFSRGQFLSDHSKYRTWDYIHHPPKAYEANINVCFLSNINYLVSGFSEHENGPWTWKEWQGMNLKAKHTVRTIIWVAHRMFMVERRVGEMMISSFSEIHYKQCPHRWMCKCKAPGEVTTAHYNCCFYDVQWTGTDGECADCHAAYAACMTRNASPIVGPKLWQIKQHDKADISSFIDRFSTQGYQPTVAVSEGREKQPWCGYMFGHMFPKNDPWCGQNQMQNARKAIKAYEGKTRSETTKMRYKMLAPRWWLYMPNYRCAILQLLWHLMKYYLKNDIEDANPKHSMRQEMYDCWYDIPEGIPGIGRDEFWWLDRLTHYSERGHFQFPNRYYCMHPRVIWGEHMQTEGWKFYKYWWNSYFGPMAMDLNPIIQRSHIGMCKDMTYKAECRYMDYGCDFPVFQAAWNSTCDLNEKQKIAKSPVCMDNVHMQDVESPCRENYYVLMGWLSHRERCHLQHGKMPFNPSLTEQCVKPNQFDVENKDIADCTMWKENWMGWLWLVSRYEWEYEMSAATANWNLIPNLDPRQAQRDTMKRLMGWYCRHYDHERKNWLRCEWVPHETSFNNLWCFAQPYMDKAQQVHGAIKRRCFLHTRLGTKSYFPDCHVQLFCFDDQCMGKCEKQIMLFFIQNGHGPRCIKFHKGISTHTMENDPLRHTIRCCSTWGYNFSVFSSLPFWYYRHKYMMDSHAYD"),
]

# Tokenize
ids, batch_token, lengths = batchConverter(data)

# Load Model
model = load_prot_flash_base()

# Data Parallel Inference
model = torch.nn.DataParallel(model) # commenting out this line
model.to(device='cuda')              # and this one removes the error

# Embedding
with torch.no_grad():
    token_embedding = model(batch_token, lengths)

# Generate per-sequence representations via averaging
sequence_representations = []
for i, (_, seq) in enumerate(data):
    sequence_representations.append(token_embedding[i, 0: len(seq) + 1].mean(0))
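
With DataParallel disabled, this completes and each entry of sequence_representations is a 768-dim vector. Quick sanity check (sketch, assuming the 768 hidden size from the expected shape below):

assert all(rep.shape == (768,) for rep in sequence_representations)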

Expected behavior
I expect a token-embedding tensor of shape [4, L, 768], where L is the length of the longest sequence in the batch.

Error encountered

einops.EinopsError:  Error while processing rearrange-reduction pattern "b (g j) -> b g () j".
 Input tensor shape: torch.Size([1, 1315]). Additional info: {'j': 21}.
 Shape mismatch, can't divide axis of length 1315 in chunks of 21
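
If it helps with triage, my guess (unverified) is a scatter/padding mismatch: DataParallel splits batch_token along the batch dimension only (the per-replica batch of 1 suggests the four sequences were scattered across four GPUs), so each replica still sees the full padded length produced by batchConverter (1315 here), which is not a multiple of the chunk size the grouping step expects (1315 = 21 × 62 + 13). If that is right, having each replica tokenize and pad its own sub-batch should sidestep the error. A rough sketch of that idea in plain PyTorch (not a ProtFlash API; I have not verified whether lengths also needs to move to each device):

# Manual data parallelism: one replica per GPU, each tokenizing its own
# sub-batch so the padded length stays consistent with the lengths it receives.
devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
replicas = [load_prot_flash_base().to(d).eval() for d in devices]

chunks = [data[i::len(devices)] for i in range(len(devices))]  # round-robin split

token_embeddings = []
with torch.no_grad():
    for m, chunk, d in zip(replicas, chunks, devices):
        if not chunk:
            continue
        _ids, chunk_tokens, chunk_lengths = batchConverter(chunk)
        # Note: each chunk has its own padded length, so concatenating the
        # results afterwards would require re-padding to a common length.
        token_embeddings.append(m(chunk_tokens.to(d), chunk_lengths).cpu())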

Workaround
The only fix I have found on my end is to not use torch.nn.DataParallel, as sketched below.
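
For reference, running without DataParallel (on CPU, or presumably a single GPU) avoids the error. The single-GPU variant would look like this (minimal sketch; without DataParallel doing the scatter, batch_token needs an explicit move to the device):

model = load_prot_flash_base().to('cuda').eval()

with torch.no_grad():
    token_embedding = model(batch_token.to('cuda'), lengths)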

Thanks for any insights and solutions!
