Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion datasets/activations_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ class ActivationDataset(BaseDataset):
"""
Dataset for preparing model activations
"""
def __init__(self, activations, labels, filenames):
    """Build a dataset from raw model activations.

    Args:
        activations: iterable of activation tensors; handed to
            prepare_data (defined later in the class) for cloning/stacking.
        labels: class indices; stored as a long tensor.
        filenames: path-like objects; stored as plain strings so the
            dataset yields hashable, serializable identifiers.
    """
    self.audio_data = self.prepare_data(activations)
    self.labels = torch.tensor(labels, dtype=torch.long)
    self.filenames = [str(path) for path in filenames]

def prepare_data(self, activations):
activations = [act.clone() for act in activations]
Expand Down
5 changes: 4 additions & 1 deletion datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ class BaseDataset(Dataset, abc.ABC):
def __init__(self):
    """Declare the backing fields that concrete subclasses must populate."""
    # All three stay None until a subclass's data-preparation step runs;
    # __getitem__ guards against reading them before then.
    self.audio_data = self.labels = self.filenames = None

def __getitem__(self, idx):
    """Return the (audio_data, label, filename) triple for sample ``idx``.

    Raises:
        NotImplementedError: if a subclass has not initialized the
            corresponding backing field yet.
    """
    # NOTE(review): the pasted diff left the pre-change two-tuple `return`
    # above the filenames check, which made the check and the three-tuple
    # return unreachable; only the new return is kept here.
    if self.audio_data is None:
        raise NotImplementedError("self.audio_data is not initialized")
    if self.labels is None:
        raise NotImplementedError("self.labels is not initialized")
    if self.filenames is None:
        raise NotImplementedError("self.filenames is not initialized")
    return self.audio_data[idx], self.labels[idx], self.filenames[idx]

def __len__(self):
    # Dataset size is driven by the labels container; assumes labels,
    # audio_data and filenames are kept the same length by subclasses.
    return len(self.labels)
Expand Down
18 changes: 14 additions & 4 deletions datasets/embeddings_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path
from .base_dataset import BaseDataset

import chromadb
Expand Down Expand Up @@ -37,10 +38,10 @@ def prepare_data(self):
the ChromaDB database using the get_chroma_embeddings method.
"""
if self.source_type == "npy":
audio_data, labels = self.get_npy_embeddings(
audio_data, labels, filenames = self.get_npy_embeddings(
self.source_path, self.split)
elif self.source_type == "chromadb":
audio_data, labels = self.get_chroma_embeddings(
audio_data, labels, filenames = self.get_chroma_embeddings(
self.source_path, self.split, self.collection_name)
else:
raise ValueError(
Expand All @@ -49,6 +50,7 @@ def prepare_data(self):

self.audio_data = torch.tensor(audio_data, dtype=torch.float32)
self.labels = torch.tensor(labels, dtype=torch.long)
self.filenames = filenames

def get_npy_embeddings(self, source_path, split):
"""
Expand All @@ -59,7 +61,11 @@ def get_npy_embeddings(self, source_path, split):
data = source[split]
embeddings = np.array([item['embedding'] for item in data])
labels = self.lb.fit_transform([item['label'] for item in data])
return embeddings, labels
filenames = [
'/'.join(Path(item['file_path']).parts[-2:])
for item in data
]
return embeddings, labels, filenames

def get_chroma_embeddings(self, source_path, split, collection_name):
"""
Expand All @@ -72,4 +78,8 @@ def get_chroma_embeddings(self, source_path, split, collection_name):
embeddings = np.array(results['embeddings'], dtype=np.float32)
labels = self.lb.fit_transform(
[item['label'] for item in results['metadatas']])
return embeddings, labels
filenames = [
'/'.join(Path(item['file_path']).parts[-2:])
for item in results['metadatas']
]
return embeddings, labels, filenames
23 changes: 21 additions & 2 deletions interpretability_scripts/embedding_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from utils import (
get_loaders,
save_emb_metrics,
save_embs_to_csv,
save_visualization,
evaluate_emb_model
)
Expand Down Expand Up @@ -43,6 +44,12 @@ def main():
default="./result/gender.png",
help="Save path for embeddings visualisation"
)
parser.add_argument(
"--csv_path",
type=str,
default="./result/embeddings.csv",
help="Save path for embeddings csv"
)
args = parser.parse_args()

if not os.path.exists(args.source_path):
Expand All @@ -60,12 +67,24 @@ def main():
train_emb_model(model, train_loader, optimizer,
criterion, num_epoch=300, device=device)

metrics = evaluate_emb_model(model, test_loader, device)
metrics, true_labels, pred_labels = evaluate_emb_model(
model,
test_loader,
device
)
save_emb_metrics(metrics, args.eval_path)
save_visualization(
coords = save_visualization(
model, test_dataset.audio_data.numpy(),
test_dataset.labels.numpy(), args.visual_path, device=device
)
data = {
"filenames": test_dataset.filenames,
"true_labels": true_labels,
"predictions": pred_labels,
"x_coord": [item[0] for item in coords],
"y_coord": [item[1] for item in coords]
}
save_embs_to_csv(data, args.csv_path)


if __name__ == '__main__':
Expand Down
14 changes: 9 additions & 5 deletions interpretability_scripts/probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def load_or_extract_acts(i, chunk, acts_model, device, layer, mode):
identity_file=f"{mode}_identity_{i}_{num}.pt")
activations.append(layer_acts[layer])

dataset = ActivationDataset(activations, labels)
dataset = ActivationDataset(activations, labels, chunk)

return dataset, activations, labels

Expand Down Expand Up @@ -116,18 +116,22 @@ def test(
save_tmp(labels, "test", f"tmp_labels_{i}.pt")

probing_model.eval()
y_pred_chunk, y_true_chunk = [], []
y_pred_chunk, y_true_chunk, filenames = [], [], []
with torch.no_grad():
for X_batch, y_batch in loader:
for X_batch, y_batch, path_batch in loader:
X_batch = X_batch.to(device)
outputs = probing_model(X_batch).cpu()
y_pred_chunk.extend(outputs.numpy())
y_true_chunk.extend(y_batch.numpy())
filenames.extend(
['/'.join(Path(filepath).parts[-2:]) for filepath in
path_batch]
)

for filepath, true_label, pred in zip(chunk, y_true_chunk,
for filepath, true_label, pred in zip(filenames, y_true_chunk,
y_pred_chunk):
chunk_rows.append({
"filename": os.path.relpath(filepath, args.test_dir),
"filename": str(filepath),
"true_label": true_label,
f"prediction_{layer}": int(pred > 0.5)
})
Expand Down
4 changes: 2 additions & 2 deletions models/train_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def train_emb_model(
for epoch in tqdm(range(num_epoch), desc="Training Progress"):
model.train()

for embeddings_batch, labels_batch in tqdm(
for embeddings_batch, labels_batch, _ in tqdm(
train_loader, desc=f"Epoch {epoch + 1}/{num_epoch}"
):
embeddings_batch = embeddings_batch.to(device)
Expand Down Expand Up @@ -49,7 +49,7 @@ def train_probing_model(

model.train()
for epoch in tqdm(range(num_epoch), desc="Training Progress"):
for X_batch, y_batch in train_loader:
for X_batch, y_batch, _ in train_loader:
X_batch, y_batch = X_batch.to(device), y_batch.to(device).float()

optimizer.zero_grad()
Expand Down
22 changes: 22 additions & 0 deletions results/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Здесь находятся датафреймы и визуализации для каждого топика

## Структура датафрейма по *layer analysis*:
- название файла
- лейбл
- предикт слоя 1
- предикт слоя 2
- и т. д.

## Структура датафрейма по *embedding analysis*:
- название файла
- лейбл
- координата по оси x
- координата по оси y

## Структура датафрейма по *embedding analysis с предиктами*:
- название файла
- предикт

## Правила наименования файлов:
Называйте всё в следующем формате:
`(название интерпретируемой модели)_(задача интерпретации)_(что находится внутри).(расширение)`
Loading