179 changes: 140 additions & 39 deletions modeling/models/tiger.py
@@ -20,6 +20,7 @@ def __init__(
activation='relu',
layer_norm_eps=1e-9,
initializer_range=0.02,
beam_width=1
):
super().__init__()

@@ -33,6 +34,7 @@ def __init__(
self._num_decoder_layers = num_decoder_layers
self._dim_feedforward = dim_feedforward
self._layer_norm_eps = layer_norm_eps
self._beam_width = beam_width

self._sem_id_len = 4

@@ -306,21 +308,21 @@ def forward(self, inputs):

batch_size = encoder_input_emb.size(0)

tgt = self.bos_embedding[None, None].tile(dims=[batch_size, 1, 1]) # (batch_size, 1, embedding_dim)
greedy_tgt = self.bos_embedding[None, None].tile(dims=[batch_size, 1, 1])  # (batch_size, 1, embedding_dim)

memory_key_padding_mask = ~after_encoder_mask
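# nn.TransformerDecoder treats True in key_padding_mask as padding, hence the inversion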

argmaxes = []
scores = []
losses = []
greedy_scores = []
greedy_losses = []
greedy_argmaxes = []

for step in range(self._sem_id_len):
tgt_mask = nn.Transformer.generate_square_subsequent_mask(
tgt.size(1), device=DEVICE
greedy_tgt.size(1), device=DEVICE
) # (L, L)

decoder_output = self._decoder(
tgt=tgt,
tgt=greedy_tgt,
memory=after_encoder_emb,
tgt_mask=tgt_mask,
memory_key_padding_mask=memory_key_padding_mask
@@ -329,13 +331,13 @@ def forward(self, inputs):
last_output = decoder_output[:, -1, :]  # (batch_size, embedding_dim)
weights = self.codebook_embeddings.weight[step * self._codebook_size: (step + 1) * self._codebook_size]
logits = last_output @ weights.T # (batch_size, codebook_size)
scores.append(logits)
greedy_scores.append(logits)

pred_tokens = torch.argmax(logits, dim=-1) # (batch_size,)
argmaxes.append(pred_tokens)
greedy_argmaxes.append(pred_tokens)

loss = nn.functional.cross_entropy(logits, target_tokens[:, step])
losses.append(loss)
greedy_losses.append(loss)

if step < self._sem_id_len - 1:
next_embed = self.codebook_embeddings(
@@ -346,55 +348,63 @@ def forward(self, inputs):
next_embed += pos_emb

next_embed = next_embed.unsqueeze(1) # (batch_size, 1, embedding_dim)
tgt = torch.cat([tgt, next_embed], dim=1)

all_items_semantic_ids = inputs['all_semantic_ids'] # (num_items, sid_length)
all_items_semantic_ids = all_items_semantic_ids + 256 * torch.arange(4, device=all_items_semantic_ids.device)
greedy_tgt = torch.cat([greedy_tgt, next_embed], dim=1)

if self._beam_width > 1:
beam_sequences = []
for i in range(batch_size):
memory_i = after_encoder_emb[i].unsqueeze(0)
memory_mask_i = after_encoder_mask[i].unsqueeze(0)
sequence, _ = self._beam_search(
memory_i, memory_mask_i, self._beam_width, self._sem_id_len
)
beam_sequences.append(sequence)

beam_sequences = torch.tensor(
beam_sequences, device=DEVICE, dtype=torch.long
)
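# beam_sequences: (batch_size, sem_id_len); split columns below to mirror the greedy outputs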
argmaxes = [beam_sequences[:, i] for i in range(self._sem_id_len)]
else:
argmaxes = greedy_argmaxes

decoder_scores = torch.softmax(torch.stack(scores, dim=1) / torch.clip(torch.exp(self.scale), min=0.01, max=100), dim=-1)
decoder_scores = decoder_scores.reshape(decoder_scores.shape[0], decoder_scores.shape[1] * decoder_scores.shape[2])
all_items_semantic_ids = inputs['all_semantic_ids'] # (num_items, sid_length)
all_items_semantic_ids = all_items_semantic_ids + 256 * torch.arange(4,
device=all_items_semantic_ids.device)
# why not clamp? (torch.clip is an alias of torch.clamp)
decoder_scores = torch.softmax(
torch.stack(greedy_scores, dim=1) / torch.clip(torch.exp(self.scale), min=0.01, max=100), dim=-1)
decoder_scores = decoder_scores.reshape(decoder_scores.shape[0],
decoder_scores.shape[1] * decoder_scores.shape[2])

all_items, id_dim = all_items_semantic_ids.shape
batch_indices = torch.arange(batch_size).unsqueeze(1).unsqueeze(2)
ids_expanded = all_items_semantic_ids.unsqueeze(0).expand(batch_size, -1, -1)

all_item_scores = decoder_scores[batch_indices.expand(-1, all_items, id_dim), ids_expanded] # (batch_size, num_items, sid_length)
all_item_scores = decoder_scores[
batch_indices.expand(-1, all_items, id_dim), ids_expanded] # (batch_size, num_items, sid_length)
all_item_scores = all_item_scores.prod(dim=-1)

sort_indices = torch.argsort(all_item_scores, dim=-1, descending=True, stable=True)
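# Note on the lookup above: after the reshape, decoder_scores is
# (batch_size, 4 * codebook_size), and the "+ 256 * arange(4)" offset shifted
# each position's token id into its own codebook slice, so indexing with
# ids_expanded pulls out per-position probabilities directly; prod(dim=-1)
# then scores an item as the product of its four token probabilities.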

return {
"decoder_scores_1": scores[0],
"decoder_scores_2": scores[1],
"decoder_scores_3": scores[2],
"decoder_scores_4": scores[3],
"decoder_scores_1": greedy_scores[0],
"decoder_scores_2": greedy_scores[1],
"decoder_scores_3": greedy_scores[2],
"decoder_scores_4": greedy_scores[3],

"decoder_argmax_1": argmaxes[0],
"decoder_argmax_2": argmaxes[1],
"decoder_argmax_3": argmaxes[2],
"decoder_argmax_4": argmaxes[3],

"decoder_loss_1": losses[0], # (1, )
"decoder_loss_2": losses[1], # (1, )
"decoder_loss_3": losses[2], # (1, )
"decoder_loss_4": losses[3], # (1, )
"decoder_loss_1": greedy_losses[0], # (1, )
"decoder_loss_2": greedy_losses[1], # (1, )
"decoder_loss_3": greedy_losses[2], # (1, )
"decoder_loss_4": greedy_losses[3], # (1, )

"predictions": sort_indices,
"scale": torch.exp(self.scale).item(),
}

def _apply_encoder(
self,
embeddings, # (batch_size, max_seq_len, embedding_dim)
@@ -407,4 +417,95 @@ def _apply_encoder(
src=embeddings, src_key_padding_mask=~mask
) # (batch_size, seq_len, embedding_dim)

return embeddings, mask

def _beam_search(
self,
memory: torch.Tensor,
memory_mask: torch.Tensor,
beam_width: int,
max_len: int
):
"""
Perform beam search for a single example.

Args:
memory: Encoder output (1, mem_seq_len, embedding_dim)
memory_mask: Memory mask (1, mem_seq_len)
beam_width: Number of beams to maintain
max_len: Length of sequence to generate

Returns:
    sequence: Best token sequence found (list of ints, length max_len)
    score: Cumulative log-probability of that sequence
"""
beams = [([], 0.0, self.bos_embedding[None, None])]

for step in range(max_len):
current_beam_size = len(beams)
tgt_embs = torch.cat([beam[2] for beam in beams], dim=0)

# Expand memory for current beams
memory_expanded = memory.expand(current_beam_size, -1, -1)
memory_mask_expanded = memory_mask.expand(current_beam_size, -1)

# Create target mask
tgt_mask = nn.Transformer.generate_square_subsequent_mask(
tgt_embs.size(1), device=DEVICE
)

# Run decoder
decoder_output = self._decoder(
tgt=tgt_embs,
memory=memory_expanded,
tgt_mask=tgt_mask,
memory_key_padding_mask=~memory_mask_expanded
)
last_output = decoder_output[:, -1, :]

# Calculate token probabilities
weights = self.codebook_embeddings.weight[
step * self._codebook_size: (step + 1) * self._codebook_size
]
logits = last_output @ weights.T
log_probs = torch.log_softmax(logits, dim=-1)

# Calculate new scores
prev_scores = torch.tensor(
[beam[1] for beam in beams], device=DEVICE
).unsqueeze(1)
new_scores = prev_scores + log_probs

# Flatten to select top candidates
new_scores_flat = new_scores.view(-1)
topk_scores, topk_indices = new_scores_flat.topk(beam_width, dim=0)

# Determine beam and token indices
beam_indices = topk_indices // self._codebook_size
token_indices = topk_indices % self._codebook_size
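# Example: with codebook_size = 256, flat index 517 -> beam 517 // 256 = 2,
# token 517 % 256 = 5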

# Create new beams
new_beams = []
for i in range(beam_width):
beam_idx = beam_indices[i].item()
token = token_indices[i].item()
score = topk_scores[i].item()

old_beam = beams[beam_idx]
new_tokens = old_beam[0] + [token]

# Build next target embeddings (nn.Embedding expects tensor indices,
# so wrap the raw ints before the lookup)
token_embed = self.codebook_embeddings(
torch.tensor(step * self._codebook_size + token, device=DEVICE)
)  # (embedding_dim,)
pos_embed = self.sem_id_position_embeddings(
torch.tensor(step, device=DEVICE)
)  # (embedding_dim,)
next_embed = (token_embed + pos_embed)[None, None]  # (1, 1, embedding_dim)
new_tgt_emb = torch.cat([old_beam[2], next_embed], dim=1)

new_beams.append((new_tokens, score, new_tgt_emb))

beams = new_beams

return beams[0][0], beams[0][1]  # topk sorts scores descending, so beams[0] is the best beam
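
For reference, a minimal usage sketch of _beam_search, mirroring the per-example loop in forward above. The model, after_encoder_emb, and after_encoder_mask names are illustrative assumptions, not part of this diff:

# Hypothetical driver, assuming `model` is a constructed TIGER instance,
# `after_encoder_emb` is its encoder output of shape (batch_size, seq_len, emb_dim),
# and `after_encoder_mask` is a boolean mask of shape (batch_size, seq_len).
sequences, seq_scores = [], []
for i in range(after_encoder_emb.size(0)):
    tokens, log_prob = model._beam_search(
        memory=after_encoder_emb[i].unsqueeze(0),        # (1, seq_len, emb_dim)
        memory_mask=after_encoder_mask[i].unsqueeze(0),  # (1, seq_len)
        beam_width=4,  # illustrative width
        max_len=4,     # sem_id_len
    )
    sequences.append(tokens)     # four codebook token ids per example
    seq_scores.append(log_prob)  # cumulative log-probability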