4 changes: 2 additions & 2 deletions notebooks/AmazonBeautyDatasetStatistics.ipynb
@@ -405,7 +405,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -419,7 +419,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
"version": "3.10.12"
}
},
"nbformat": 4,
64 changes: 64 additions & 0 deletions scripts/plum-yambda/callbacks.py
@@ -0,0 +1,64 @@
import torch

import irec.callbacks as cb
from irec.runners import TrainingRunner, TrainingRunnerContext

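# Data-driven codebook initialization: for each codebook, a random batch of item
# embeddings is encoded, the contributions of all previously initialized codebooks
# are subtracted, and the resulting residual vectors are used as the initial
# centroids. Runs once before training starts.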
class InitCodebooks(cb.TrainingCallback):
def __init__(self, dataloader):
super().__init__()
self._dataloader = dataloader

@torch.no_grad()
def before_run(self, runner: TrainingRunner):
for i in range(len(runner.model.codebooks)):
X = next(iter(self._dataloader))['embedding']
idx = torch.randperm(X.shape[0], device=X.device)[:len(runner.model.codebooks[i])]
remainder = runner.model.encoder(X[idx])

for j in range(i):
codebook_indices = runner.model.get_codebook_indices(remainder, runner.model.codebooks[j])
codebook_vectors = runner.model.codebooks[j][codebook_indices]
remainder = remainder - codebook_vectors

runner.model.codebooks[i].data = remainder.detach()


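# After every training step, counts how often each centroid of every codebook is
# selected across the dataloader and reassigns centroids that received no
# assignments ("dead" centroids) to random residual vectors from a fresh batch.
# The number of reassigned centroids per codebook is reported as `num_dead/{i}`.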
class FixDeadCentroids(cb.TrainingCallback):
def __init__(self, dataloader):
super().__init__()
self._dataloader = dataloader

def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)):
context.metrics[f'num_dead/{i}'] = num_fixed

@torch.no_grad()
def fix_dead_codebooks(self, runner: TrainingRunner):
num_fixed = []
for codebook_idx, codebook in enumerate(runner.model.codebooks):
centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
random_batch = next(iter(self._dataloader))['embedding']

for batch in self._dataloader:
remainder = runner.model.encoder(batch['embedding'])
for l in range(codebook_idx):
ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l])
remainder = remainder - runner.model.codebooks[l][ind]

indices = runner.model.get_codebook_indices(remainder, codebook)
centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))

dead_mask = (centroid_counts == 0)
num_dead = int(dead_mask.sum().item())
num_fixed.append(num_dead)
if num_dead == 0:
continue

remainder = runner.model.encoder(random_batch)
for l in range(codebook_idx):
ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l])
remainder = remainder - runner.model.codebooks[l][ind]
remainder = remainder[torch.randperm(remainder.shape[0], device=codebook.device)][:num_dead]
codebook[dead_mask] = remainder.detach()

return num_fixed
108 changes: 108 additions & 0 deletions scripts/plum-yambda/cooc_data.py
@@ -0,0 +1,108 @@
import json
import pickle
from collections import defaultdict, Counter

import numpy as np
from loguru import logger

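# Holds the training interactions, the number of items, and a window-based
# item co-occurrence counter built from the training sessions.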
class CoocMappingDataset:
def __init__(
self,
train_sampler,
num_items,
cooccur_counter_mapping=None
):
self._train_sampler = train_sampler
self._num_items = num_items
self._cooccur_counter_mapping = cooccur_counter_mapping

@classmethod
def create(cls, inter_json_path, window_size):
max_item_id = 0
train_dataset, validation_dataset, test_dataset = [], [], []

with open(inter_json_path, 'r') as f:
user_interactions = json.load(f)

for user_id_str, item_ids in user_interactions.items():
user_id = int(user_id_str)
if item_ids:
max_item_id = max(max_item_id, max(item_ids))
assert len(item_ids) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items'
train_dataset.append({
'user.ids': [user_id],
'item.ids': item_ids[:-2],
})

cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}')

train_sampler = train_dataset

return cls(
train_sampler=train_sampler,
num_items=max_item_id + 1,
cooccur_counter_mapping=cooccur_counter_mapping
)

@classmethod
def create_from_split_part(
cls,
train_inter_json_path,
window_size
):

max_item_id = 0
train_dataset = []

with open(train_inter_json_path, 'r') as f:
train_interactions = json.load(f)

        # Process the TRAIN interactions
for user_id_str, item_ids in train_interactions.items():
user_id = int(user_id_str)
if item_ids:
max_item_id = max(max_item_id, max(item_ids))

train_dataset.append({
'user.ids': [user_id],
'item.ids': item_ids,
})

logger.debug(f'Train: {len(train_dataset)} users')
logger.debug(f'Max item ID: {max_item_id}')

cooccur_counter_mapping = cls.build_cooccur_counter_mapping(
train_dataset,
window_size=window_size
)

logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items')

return cls(
train_sampler=train_dataset,
num_items=max_item_id + 1,
cooccur_counter_mapping=cooccur_counter_mapping
)


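    # Counts, for every item, how often each other item appears within
    # `window_size` positions of it in the same session. For example, with
    # window_size=2 the item at position i is paired with the items at
    # positions i-2 .. i+2 (excluding i itself).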
@staticmethod
    def build_cooccur_counter_mapping(train_dataset, window_size):  # TODO: pass timestamps and build the window from them
cooccur_counts = defaultdict(Counter)
for session in train_dataset:
items = session['item.ids']
for i in range(len(items)):
item_i = items[i]
for j in range(max(0, i - window_size), min(len(items), i + window_size + 1)):
if i != j:
cooccur_counts[item_i][items[j]] += 1
return cooccur_counts


@property
def cooccur_counter_mapping(self):
return self._cooccur_counter_mapping
62 changes: 62 additions & 0 deletions scripts/plum-yambda/data.py
@@ -0,0 +1,62 @@
import pickle

import numpy as np
import polars as pl
import torch

from irec.data.base import BaseDataset
from irec.data.transforms import Transform

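# Item-embedding dataset backed by a Parquet file with `item_id` and `embedding` columns.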
class EmbeddingDatasetParquet(BaseDataset):
def __init__(self, data_path):
self.df = pl.read_parquet(data_path)
self.item_ids = np.array(self.df['item_id'], dtype=np.int64)
self.embeddings = np.array(self.df['embedding'].to_list(), dtype=np.float32)
print(f"embedding dim: {self.embeddings[0].shape}")

def __getitem__(self, idx):
index = self.item_ids[idx]
tensor_emb = self.embeddings[idx]
return {
'item_id': index,
'embedding': tensor_emb,
'embedding_dim': len(tensor_emb)
}

def __len__(self):
return len(self.embeddings)


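# Same interface as EmbeddingDatasetParquet, but backed by a pickled dict with
# 'item_id' and 'embedding' entries.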
class EmbeddingDataset(BaseDataset):
def __init__(self, data_path):
self.data_path = data_path
with open(data_path, 'rb') as f:
self.data = pickle.load(f)

self.item_ids = np.array(self.data['item_id'], dtype=np.int64)
self.embeddings = np.array(self.data['embedding'], dtype=np.float32)

def __getitem__(self, idx):
index = self.item_ids[idx]
tensor_emb = self.embeddings[idx]
return {
'item_id': index,
'embedding': tensor_emb,
'embedding_dim': len(tensor_emb)
}

def __len__(self):
return len(self.embeddings)


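# Batch transform that reshapes the listed keys into (batch_size, embedding_dim).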
class ProcessEmbeddings(Transform):
def __init__(self, embedding_dim, keys):
self.embedding_dim = embedding_dim
self.keys = keys

def __call__(self, batch):
for key in self.keys:
batch[key] = batch[key].reshape(-1, self.embedding_dim)
return batch
135 changes: 135 additions & 0 deletions scripts/plum-yambda/models.py
@@ -0,0 +1,135 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

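# Residual-quantized autoencoder (RQ-VAE) over item embeddings: an MLP encoder
# maps embeddings into a latent space that is quantized by a stack of residual
# codebooks, and an MLP decoder reconstructs the original embedding. The loss
# combines reconstruction, the beta-weighted commitment/codebook terms, and an
# optional contrastive term over co-occurring items. Codebooks are created
# uninitialized here and are expected to be filled by the InitCodebooks callback
# before training.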
class PlumRQVAE(nn.Module):
def __init__(
self,
input_dim,
num_codebooks,
codebook_size,
embedding_dim,
beta=0.25,
quant_loss_weight=1.0,
contrastive_loss_weight=1.0,
temperature=0.0,
):
super().__init__()
self.register_buffer('beta', torch.tensor(beta))
self.temperature = temperature

self.input_dim = input_dim
self.num_codebooks = num_codebooks
self.codebook_size = codebook_size
self.embedding_dim = embedding_dim
self.quant_loss_weight = quant_loss_weight

self.contrastive_loss_weight = contrastive_loss_weight

self.encoder = self.make_encoding_tower(input_dim, embedding_dim)
self.decoder = self.make_encoding_tower(embedding_dim, input_dim)

self.codebooks = torch.nn.ParameterList()
for _ in range(num_codebooks):
cb = torch.FloatTensor(codebook_size, embedding_dim)
#nn.init.normal_(cb)
self.codebooks.append(cb)

@staticmethod
def make_encoding_tower(d1, d2, bias=False):
return torch.nn.Sequential(
nn.Linear(d1, d1),
nn.ReLU(),
nn.Linear(d1, d2),
nn.ReLU(),
nn.Linear(d2, d2, bias=bias)
)

@staticmethod
def get_codebook_indices(remainder, codebook):
dist = torch.cdist(remainder, codebook)
return dist.argmin(dim=-1)

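    # Residual quantization with a straight-through estimator:
    # `remainder + (quantized - remainder).detach()` equals the quantized vector
    # in the forward pass, while gradients flow back to the (unquantized) remainder.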
def _quantize_representation(self, latent_vector):
latent_restored = 0
remainder = latent_vector

for codebook in self.codebooks:
codebook_indices = self.get_codebook_indices(remainder, codebook)
quantized = codebook[codebook_indices]
codebook_vectors = remainder + (quantized - remainder).detach()
latent_restored += codebook_vectors
remainder = remainder - codebook_vectors

return latent_restored

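    # InfoNCE-style in-batch contrastive loss: rows of `p_i` and `p_i_star` at the
    # same index are positives, all other pairs in the batch act as negatives.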
def contrastive_loss(self, p_i, p_i_star):
N_b = p_i.size(0)

        p_i = F.normalize(p_i, p=2, dim=-1)  # TODO: try without normalization
p_i_star = F.normalize(p_i_star, p=2, dim=-1)

similarities = torch.matmul(p_i, p_i_star.T) / self.temperature

labels = torch.arange(N_b, dtype=torch.long, device=p_i.device)

loss = F.cross_entropy(similarities, labels)

        return loss  # only over the last dimension

def forward(self, inputs):
latent_vector = self.encoder(inputs['embedding'])
# print(f"latent vector shape: {latent_vector.shape}")
# print(f"inputs embedding shape: {inputs['embedding']}")
item_ids = inputs['item_id']

latent_restored = 0
rqvae_loss = 0
clusters = []
remainder = latent_vector

for codebook in self.codebooks:
codebook_indices = self.get_codebook_indices(remainder, codebook)
clusters.append(codebook_indices)

quantized = codebook[codebook_indices]
codebook_vectors = remainder + (quantized - remainder).detach()

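            # Commitment loss (beta-weighted) pulls the encoder output toward the
            # frozen codebook vector; the codebook loss pulls the codebook vector
            # toward the frozen encoder output.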
rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach())
rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach())

latent_restored += codebook_vectors
remainder = remainder - codebook_vectors

embeddings_restored = self.decoder(latent_restored)
recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding'])

if 'cooccurrence_embedding' in inputs:
# print(f"cooccurrence_embedding shape: {inputs['cooccurrence_embedding'].shape} device {inputs['cooccurrence_embedding'].device}" )
# print(f"latent_restored shape {latent_restored.shape} device {latent_restored.device}")
cooccurrence_latent = self.encoder(inputs['cooccurrence_embedding'].to(latent_restored.device))
cooccurrence_restored = self._quantize_representation(cooccurrence_latent)
con_loss = self.contrastive_loss(latent_restored, cooccurrence_restored)
else:
con_loss = torch.as_tensor(0.0, device=latent_vector.device)

loss = (
recon_loss
+ self.quant_loss_weight * rqvae_loss
+ self.contrastive_loss_weight * con_loss
).mean()

clusters_counts = []
for cluster in clusters:
clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size))

return loss, {
'loss': loss.item(),
'recon_loss': recon_loss.mean().item(),
'rqvae_loss': rqvae_loss.mean().item(),
'con_loss': con_loss.item(),

'clusters_counts': clusters_counts,
'clusters': torch.stack(clusters).T,
'embedding_hat': embeddings_restored,
}