# multimodel.py
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Paths to a locally downloaded GPT-2 checkpoint (weights and tokenizer files).
local_model_path = "./gpt2"
local_tokenizer_path = "./gpt2"
class MLP(nn.Module):
    """Feed-forward stack of Linear layers with GELU between them."""

    def __init__(self, input_dim, hidden_dims, output_dim):
        super(MLP, self).__init__()
        # Chain input -> hidden_dims... -> output into a list of Linear layers.
        all_dims = [input_dim] + hidden_dims + [output_dim]
        self.linear_layers = nn.ModuleList()
        for i in range(len(all_dims) - 1):
            self.linear_layers.append(nn.Linear(all_dims[i], all_dims[i + 1]))

    def forward(self, x):
        for i, layer in enumerate(self.linear_layers):
            x = layer(x)
            # GELU after every layer except the last, which stays linear.
            if i < len(self.linear_layers) - 1:
                x = F.gelu(x)
        return x
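
# Example (hypothetical shapes): project 64-dim codebook vectors to GPT-2's
# 768-dim embedding space through the same hidden stack InstructTime builds below:
#   mlp = MLP(64, [64, 128, 256, 512], 768)
#   y = mlp(torch.randn(2, 16, 64))  # -> shape (2, 16, 768)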
class InstructTime(GPT2LMHeadModel):
    """GPT-2 with its vocabulary extended by the codebooks of one or more ECG tokenizers."""

    def __init__(self, config, ecgTokenizers, text_embedding=50258):
        super().__init__(config)
        self.ecgTokenizers = ecgTokenizers
        # Stack every tokenizer's codebook into one frozen embedding table.
        # Assumes all ECG tokenizers share the same hidden_dim.
        embed_vector = torch.empty(0, self.ecgTokenizers[0].hidden_dim)
        for tokenizer in self.ecgTokenizers:
            tokenizer_embed_vector = copy.deepcopy(tokenizer.quantize.embed).transpose(-1, 0)
            embed_vector = torch.cat([embed_vector, tokenizer_embed_vector], dim=0)
        self.embed_layer = nn.Embedding.from_pretrained(embed_vector)
        self.text_embedding = text_embedding  # IDs below this value are text tokens.
        self.embed = config.n_embd
        self.config.pad_token_id = self.config.eos_token_id if self.config.pad_token_id is None else self.config.pad_token_id
        # One projection MLP per tokenizer, mapping codebook vectors to the LM embedding width.
        self.projection_layers = nn.ModuleList()
        for _ in ecgTokenizers:
            mlp = MLP(self.ecgTokenizers[0].hidden_dim, [64, 128, 256, 512], self.embed)
            mlp.apply(self.init_weights_kaiming)
            self.projection_layers.append(mlp)
        # offsets[i] is the first global ID of tokenizer i; offsets[-1] is the total vocab size.
        self.offsets = [self.text_embedding]
        for tokenizer in self.ecgTokenizers:
            self.offsets.append(self.offsets[-1] + tokenizer.n_embed)

    @staticmethod
    def init_weights_kaiming(m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
            m.bias.data.fill_(0.01)
    def forward(self, *args, **kwargs):
        input_ids = kwargs["input_ids"]
        # Split positions into text and ECG by ID range.
        text_mask = torch.lt(input_ids, self.text_embedding)
        ecg_mask = ~text_mask
        # Text path: replace ECG positions with the pad ID so wte never sees
        # out-of-range indices, then zero those positions after the lookup.
        text_ids = input_ids.clone()
        text_ids[ecg_mask] = self.config.pad_token_id
        text_embeddings = self.transformer.wte(text_ids)
        text_embeddings.mul_(text_mask.float().unsqueeze(-1))
        # ECG path: for each tokenizer, look up its slice of the concatenated
        # codebook, project to the LM width, mask, and accumulate.
        ecg_embeddings = torch.zeros_like(text_embeddings)
        for i, _ in enumerate(self.ecgTokenizers):
            tokenizer_mask = (input_ids >= self.offsets[i]) & (input_ids < self.offsets[i + 1])
            tokenizer_ids = input_ids.clone()
            tokenizer_ids[~tokenizer_mask] = 0
            # Subtract the text-vocab offset so tokenizer i's IDs index its own
            # block of rows in the concatenated embed_layer.
            tokenizer_ids[tokenizer_mask] -= self.offsets[0]
            tokenizer_embeddings = self.embed_layer(tokenizer_ids)
            tokenizer_embeddings = self.projection_layers[i](tokenizer_embeddings)
            tokenizer_embeddings.mul_(tokenizer_mask.float().unsqueeze(-1))
            ecg_embeddings.add_(tokenizer_embeddings)
        # Feed GPT-2 the combined embeddings instead of raw IDs.
        kwargs["input_ids"] = None
        kwargs["inputs_embeds"] = ecg_embeddings + text_embeddings
        outputs = super().forward(*args, **kwargs)
        return outputs
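
# Example of the shared ID layout (hypothetical sizes): with text_embedding=50258
# and two ECG tokenizers with n_embed 512 and 1024, self.offsets is
# [50258, 50770, 51794]. A token ID of 50800 falls in the second range, so
# forward() routes it through embed_layer and projection_layers[1].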
class MultiTokenizer:
    """Wraps a GPT-2 text tokenizer and one or more ECG tokenizers under a single ID space."""

    def __init__(self, ecgTokenizers, dataset_keys=None) -> None:
        self.textTokenizer = GPT2Tokenizer.from_pretrained(local_tokenizer_path)
        # Keep tokenizer behavior consistent with the original; padding is controlled at the dataset level.
        new_special_tokens = ["<BET>", "<EET>"]
        self.textTokenizer.add_special_tokens({"additional_special_tokens": new_special_tokens})
        self.text_vocab_size = len(self.textTokenizer)
        self.ecgTokenizers = ecgTokenizers
        if dataset_keys is None:
            dataset_keys = [str(idx) for idx in range(len(ecgTokenizers))]
        if len(dataset_keys) != len(ecgTokenizers):
            raise ValueError("dataset_keys and ecgTokenizers must have the same length.")
        self.dataset_keys = [key.lower() if isinstance(key, str) else str(key) for key in dataset_keys]
        self.key_to_index = {key: idx for idx, key in enumerate(self.dataset_keys)}
        self.pad_token_id = self.textTokenizer.eos_token_id
        self.eos_token_id = self.textTokenizer.eos_token_id
        self.offsets = self._calculate_offsets()

    def _calculate_offsets(self):
        # offsets[i] is the first global ID of ECG tokenizer i, placed after the text vocab.
        offsets = []
        current_offset = self.text_vocab_size
        for tokenizer in self.ecgTokenizers:
            offsets.append(current_offset)
            current_offset += tokenizer.n_embed
        return offsets

    def vocabSize_all(self):
        return self.text_vocab_size + sum(tokenizer.n_embed for tokenizer in self.ecgTokenizers)

    def get_model_index(self, dataset_key: str) -> int:
        key = dataset_key.lower()
        if key not in self.key_to_index:
            raise ValueError(f"Dataset key '{dataset_key}' not found in tokenizer mapping.")
        return self.key_to_index[key]

    def encode(self, input, model_id=0):
        # Strings go through the GPT-2 tokenizer; tensors go through the selected
        # ECG tokenizer, whose indices are shifted into the shared ID space.
        if isinstance(input, str):
            return self.textTokenizer(input)["input_ids"]
        elif isinstance(input, torch.Tensor):
            input = input.to("cpu")
            if model_id < len(self.ecgTokenizers):
                tokenizer_index = model_id
                _, _, indices = self.ecgTokenizers[tokenizer_index](input)
                return indices + self.offsets[tokenizer_index]
            raise ValueError(f"Invalid model_id. Please provide a number between 0 and {len(self.ecgTokenizers) - 1}.")
        else:
            raise ValueError("Unsupported input type. Please provide either a string or a torch.Tensor.")

    def decode(self, input, skip_special_tokens=True):
        return self.textTokenizer.decode(input, skip_special_tokens=skip_special_tokens)
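
# Minimal usage sketch. Assumptions: "./gpt2" holds a downloaded GPT-2 checkpoint,
# and MockECGTokenizer is a hypothetical stand-in for a real VQ tokenizer exposing
# the interface used above (hidden_dim, n_embed, quantize.embed, forward -> (_, _, indices)).
if __name__ == "__main__":
    class MockECGTokenizer(nn.Module):
        def __init__(self, hidden_dim=64, n_embed=512):
            super().__init__()
            self.hidden_dim = hidden_dim
            self.n_embed = n_embed
            # Codebook stored as (hidden_dim, n_embed), matching the transpose(-1, 0) above.
            self.quantize = nn.Module()
            self.quantize.embed = torch.randn(hidden_dim, n_embed)

        def forward(self, x):
            # Fake quantization: one random codebook index per time step.
            indices = torch.randint(0, self.n_embed, (x.shape[0], x.shape[1]))
            return x, torch.tensor(0.0), indices

    # "ptbxl" and "mimic" are illustrative dataset keys, not names from the original code.
    multi_tok = MultiTokenizer([MockECGTokenizer(), MockECGTokenizer(n_embed=1024)],
                               dataset_keys=["ptbxl", "mimic"])
    text_ids = multi_tok.encode("Diagnose the following ECG: <BET>")
    ecg_ids = multi_tok.encode(torch.randn(1, 16, 12), model_id=multi_tok.get_model_index("ptbxl"))
    print(len(text_ids), ecg_ids.shape, multi_tok.vocabSize_all())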