From f287f5a5d3914811ff40c4b1af944b3c982dbfc3 Mon Sep 17 00:00:00 2001 From: RalphHan Date: Tue, 29 Aug 2023 17:26:33 +0800 Subject: [PATCH] cache text embedding in inference --- model/mdm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/model/mdm.py b/model/mdm.py index 14fd5bda5..38881e40c 100644 --- a/model/mdm.py +++ b/model/mdm.py @@ -148,7 +148,11 @@ def forward(self, x, timesteps, y=None): force_mask = y.get('uncond', False) if 'text' in self.cond_mode: - enc_text = self.encode_text(y['text']) + if 'text' in y: + text_embed = self.encode_text(y['text']) + del y['text'] + y['text_embed'] = text_embed + enc_text = y["text_embed"] emb += self.embed_text(self.mask_cond(enc_text, force_mask=force_mask)) if 'action' in self.cond_mode: action_emb = self.embed_action(y['action'])