90 changes: 90 additions & 0 deletions examples/AnthropicHH-Bias/README.md
@@ -0,0 +1,90 @@
# Anthropic HH Bias Example

This folder shows **one minimal, self-contained example** of how to use
[`kronfluence`](https://github.com/pomonam/kronfluence) to:

1. Fit EKFAC influence *factors* on a large language model (the 410M-parameter
   Pythia model) using a subset of 10k Anthropic-HH training samples.
2. Compute *pairwise* influence scores between that training set and the
   "Stereotypical Bias" evaluation set.

The goal of the example is to be **copy-paste simple**: after installing the
requirements, two commands will produce both the factors and the influence
scores.

---

## 1 · Quick start

```bash
# (1) Create a fresh virtual environment – optional but recommended
python -m venv ekfac-venv          # or: conda create -n ekfac python=3.10
source ekfac-venv/bin/activate

# (2) Install the few extra libraries this example needs
pip install -r requirements.txt

# (3) Download the model weights & fit EKFAC factors (≈ 10 min on a V100)
python fit_all_factors.py

# (4) Compute pairwise influence scores (≈ 5 min)
python compute_pairwise_scores.py

# Outputs ⇒ ./influence_results/
```

That's it – run the two scripts and you will obtain:

```
./influence_results/
├─ factors_ekfac_half/ # EKFAC blocks (state_dicts)
└─ scores_ekfac_half.npy # query × train influence scores
```
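
To sanity-check the output, you can load the score matrix directly. A minimal
sketch, assuming the `.npy` layout shown above:

```python
import numpy as np

scores = np.load("influence_results/scores_ekfac_half.npy")
# Rows index the bias query examples (a single row if query gradients are
# aggregated); columns index the Anthropic-HH training examples.
print(scores.shape)
```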

---

## 2 · Folder layout

```
AnthropicHH-Bias/
├─ fit_all_factors.py # Step-1 · compute EKFAC factors
├─ compute_pairwise_scores.py # Step-2 · compute pairwise scores
├─ task.py # Loss + measurement definitions
├─ utils.py # Helper functions (dataset loading, metrics…)
├─ SFT_Trainer_Lora.py # Optional LoRA fine-tuning script
├─ requirements.txt # Extra runtime deps (torch + transformers …)
└─ README.md # ← this file
```

If you only care about influence scores you can ignore `SFT_Trainer_Lora.py` —
it shows how to do an *optional* LoRA fine-tuning pass before computing EKFAC.
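
If you do want that optional fine-tuning pass, it is a single extra command
(its hyperparameters live at the top of the script):

```bash
python SFT_Trainer_Lora.py
```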

---

## 3 · What happens under the hood?

1. **`fit_all_factors.py`**
   • loads Pythia-410M and selects the last transformer block(s) to track
   • streams 10k examples from Anthropic-HH and fits the EKFAC statistics
   • saves the factors under `./influence_results/factors_ekfac_half/`

2. **`compute_pairwise_scores.py`**
   • reloads the same model and factors
   • computes gradient similarities between every training example and every
     evaluation example (the "Stereotypical Bias" set)
   • stores the final influence matrix `scores_ekfac_half.npy`

All hyper-parameters (batch sizes, half-precision flag, etc.) live at the top of
`fit_all_factors.py` so you only need to edit them **once**.
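
For reference, that configuration block looks like this (the values below
mirror `fit_all_factors.py` as committed in this PR; edit them there):

```python
FACTOR_STRATEGY = "ekfac"          # influence factorization strategy
QUERY_GRADIENT_RANK = -1           # -1 disables low-rank query gradients
USE_HALF_PRECISION = True          # fit factors / compute scores in bfloat16
USE_COMPILE = False                # optionally wrap the model in torch.compile
COMPUTE_PER_TOKEN_SCORES = False   # per-token vs. per-sequence scores
SIZE_OF_DATASET = 10000            # Anthropic-HH samples used to fit the factors
```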

---

## 4 · Troubleshooting / FAQ

• `ImportError: kronfluence` → make sure the library is installed in the current
environment (`pip install kronfluence==1.0.1`).

• GPU OOM during factor fitting → lower `initial_per_device_batch_size_attempt`
in `fit_all_factors.py` (e.g. from *32* to *8*), as in the sketch below.
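
A minimal sketch of the adjusted call (it mirrors the one at the bottom of
`fit_all_factors.py`):

```python
analyzer.fit_all_factors(
    factors_name=factors_name,
    dataset=anthropic_dataset,
    per_device_batch_size=None,               # let kronfluence search for a workable batch size
    factor_args=factor_args,
    initial_per_device_batch_size_attempt=8,  # smaller starting point to avoid OOM
    overwrite_output_dir=False,
)
```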

---
117 changes: 117 additions & 0 deletions examples/AnthropicHH-Bias/SFT_Trainer_Lora.py
@@ -0,0 +1,117 @@
import os
import torch

from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
)
from trl import SFTTrainer, SFTConfig
from peft import get_peft_model, LoraConfig, TaskType

# ─── Configuration ─────────────────────────────────────────────────────────────
# Paths
os.environ["TRANSFORMERS_CACHE"] = ...  # set to the directory to cache / load the base model from

model_name = "EleutherAI/pythia-410m"
output_dir = "pythia_410m_sft_hh_full_sft_trainer" # where to save the trained model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dataset = load_dataset("Dahoas/static-hh")

# LoRA hyperparameters
lora_r = 32
lora_alpha = 32
lora_dropout = 0.05

# Training hyperparameters
learning_rate = 1e-4
num_train_epochs = 3
max_length = 256
per_device_train_batch_size = 8
per_device_eval_batch_size = 8
gradient_accumulation_steps = 1

# ─── Tokenizer & Model Loading ────────────────────────────────────────────────

tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=os.environ["TRANSFORMERS_CACHE"]
)
# Ensure pad token for causal LM
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
model_name,
    device_map={"": device},  # map all modules to the selected device
# Use float32 to avoid mixed precision issues
torch_dtype=torch.float32,
cache_dir=os.environ["TRANSFORMERS_CACHE"]
)

# ─── Apply LoRA (PEFT) ─────────────────────────────────────────────────────────

peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=["q_proj", "k_proj", "v_proj", "dense"]
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


def format_dataset(examples):
"""Format the dataset for SFTTrainer with prompt-completion structure."""
return {
"prompt": examples["prompt"],
"completion": examples["chosen"]
}

train_dataset = dataset["train"].map(format_dataset, batched=True, remove_columns=dataset["train"].column_names)
eval_dataset = dataset["test"].map(format_dataset, batched=True, remove_columns=dataset["test"].column_names)

# ─── Training Setup ───────────────────────────────────────────────────────────

training_args = SFTConfig(
output_dir=output_dir,
learning_rate=learning_rate,
per_device_train_batch_size=per_device_train_batch_size,
per_device_eval_batch_size=per_device_eval_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
max_seq_length=max_length,
num_train_epochs=num_train_epochs,
weight_decay=0.01,
logging_steps=1000,
# Disable fp16 to avoid mixed precision conflicts
fp16=False,
bf16=False,
push_to_hub=True,
load_best_model_at_end=True,
max_steps=-1,
# Add evaluation and save strategies
eval_strategy="steps",
save_strategy="steps",
eval_steps=1000,
save_steps=10000,
save_total_limit=3, # Keep only the last 3 checkpoints
report_to=["wandb"],
    hub_token=...  # your Hugging Face access token
)

trainer = SFTTrainer(
model=model,
processing_class=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
)

# ─── Run Training ─────────────────────────────────────────────────────────────

trainer.train()
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"Model and tokenizer saved to {output_dir}")
55 changes: 55 additions & 0 deletions examples/AnthropicHH-Bias/compute_pairwise_scores.py
@@ -0,0 +1,55 @@
'''
Computes pairwise influence scores between the Anthropic-HH training set and the
"Stereotypical Bias" evaluation set for the Pythia-410M model.
'''
import torch
from kronfluence.arguments import ScoreArguments
from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
from fit_all_factors import (
factor_args,
USE_HALF_PRECISION,
USE_COMPILE,
COMPUTE_PER_TOKEN_SCORES,
QUERY_GRADIENT_RANK,
analyzer,
factors_name
)
from utils import get_anthropic_dataset, get_bias_agreement_dataset
from task import tokenizer

# Compute pairwise scores.

score_args = ScoreArguments()
scores_name = factor_args.strategy
if USE_HALF_PRECISION:
score_args = all_low_precision_score_arguments(dtype=torch.bfloat16)
scores_name += "_half"
if USE_COMPILE:
scores_name += "_compile"
if COMPUTE_PER_TOKEN_SCORES:
score_args.compute_per_token_scores = True
scores_name += "_per_token"
rank = QUERY_GRADIENT_RANK if QUERY_GRADIENT_RANK != -1 else None
if rank is not None:
score_args.query_gradient_low_rank = rank
score_args.query_gradient_accumulation_steps = 10
scores_name += f"_qlr{rank}"

score_args.aggregate_query_gradients = True  # False by default; aggregating query gradients (one aggregated score per training example) is highly recommended here.

anthropic_dataset = get_anthropic_dataset(tokenizer)

bias_dataset = get_bias_agreement_dataset(tokenizer)


QUERY_BATCH_SIZE = 16
TRAIN_BATCH_SIZE = 16
analyzer.compute_pairwise_scores(
scores_name=scores_name,
score_args=score_args,
factors_name=factors_name,
query_dataset=bias_dataset,
train_dataset=anthropic_dataset,
per_device_query_batch_size=QUERY_BATCH_SIZE,
per_device_train_batch_size=TRAIN_BATCH_SIZE,
overwrite_output_dir=False
)
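
# Optional sanity check: reload the scores that were just written.
# (Assumes kronfluence's `Analyzer.load_pairwise_scores` helper; with the default
# settings the scores for all tracked modules are stored under "all_modules".)
scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
print(f"Pairwise score tensor shape: {tuple(scores.shape)}")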
78 changes: 78 additions & 0 deletions examples/AnthropicHH-Bias/fit_all_factors.py
@@ -0,0 +1,78 @@
'''
Fits the EKFAC influence factors for the Pythia-410M model on a subset of the
Anthropic-HH dataset.
'''

import os
import logging
import numpy as np
import torch
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments
from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments
from kronfluence.utils.dataset import DataLoaderKwargs
from transformers import default_data_collator
from utils import get_anthropic_dataset
from termcolor import colored
from task import (
BiasTask,
model,
tokenizer
)

CHECKPOINT_DIR = "./checkpoints"
FACTOR_STRATEGY = "ekfac"
QUERY_GRADIENT_RANK = -1
USE_HALF_PRECISION = True
USE_COMPILE = False
QUERY_BATCH_SIZE = 16
TRAIN_BATCH_SIZE = 16
PROFILE = False
COMPUTE_PER_TOKEN_SCORES = False
SIZE_OF_DATASET = 10000  # 10k Anthropic-HH samples used to estimate the Fisher / EKFAC statistics
# The modules tracked for the EKFAC factors are selected in task.py by varying the number of transformer blocks to consider (NUM_BLOCKS).

print(colored("Loading model and dataset...", "green"))


anthropic_dataset = get_anthropic_dataset(tokenizer)
# Keep the first SIZE_OF_DATASET examples; .select() returns an HF Dataset.
anthropic_dataset = anthropic_dataset.select(range(SIZE_OF_DATASET))

if os.path.isfile(os.path.join(CHECKPOINT_DIR, "model.pth")):
print(colored("Checkpoint found", "green"))
logging.info(f"Loading checkpoint from {CHECKPOINT_DIR}")
model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, "model.pth")))

# Define task and prepare model.
task = BiasTask()
model = prepare_model(model, task)
if USE_COMPILE:
model = torch.compile(model)
analyzer = Analyzer(
analysis_name="bias_pythia_410m",
model=model,
task=task,
profile=PROFILE,
)

# Configure parameters for DataLoader.
dataloader_kwargs = DataLoaderKwargs(collate_fn=default_data_collator)
analyzer.set_dataloader_kwargs(dataloader_kwargs)
# Compute influence factors.
factors_name = FACTOR_STRATEGY
factor_args = FactorArguments(strategy=FACTOR_STRATEGY, use_empirical_fisher=True)  # use_empirical_fisher is False by default.
if USE_HALF_PRECISION:
    factor_args = all_low_precision_factor_arguments(strategy=FACTOR_STRATEGY, dtype=torch.bfloat16)
    factor_args.use_empirical_fisher = True  # the helper builds fresh FactorArguments, so re-apply the empirical-Fisher setting
    factors_name += "_half"
if USE_COMPILE:
factors_name += "_compile"

analyzer.fit_all_factors(
factors_name=factors_name,
dataset=anthropic_dataset,
per_device_batch_size=None,
factor_args=factor_args,
initial_per_device_batch_size_attempt=32,
overwrite_output_dir=False,
)
7 changes: 7 additions & 0 deletions examples/AnthropicHH-Bias/requirements.txt
@@ -0,0 +1,7 @@
torch
transformers
termcolor
trl
numpy
peft
datasets
22 changes: 22 additions & 0 deletions examples/AnthropicHH-Bias/results.md
@@ -0,0 +1,22 @@
# Bias-loss results

Lower is better (negative log-likelihood on the Stereotypical Bias dev set).

| Model / Training subset | Bias NLL |
|-------------------------|---------:|
| Base Pythia-410M | 0.513 |
| SFT on full HH (96 k examples) | 0.4871 |
| SFT on random 45 k examples | 0.4987 |
| SFT on **lowest** 45 k EKFAC-ranked examples | 0.5965 |
| SFT on **highest** 45 k EKFAC-ranked examples | **0.3727** |
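
The 45 k subsets above were obtained by ranking training examples with the
pairwise EKFAC scores. A minimal sketch of such a split, assuming the scores are
available as the `.npy` file described in the example README:

```python
import numpy as np

scores = np.load("influence_results/scores_ekfac_half.npy")
per_example = scores.sum(axis=0)       # total influence on the bias query set
order = np.argsort(per_example)
lowest_45k = order[:45_000]            # least influential training examples
highest_45k = order[-45_000:]          # most influential training examples
```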

---

## LoRA adapter checkpoints on Hugging Face

```text
ncgc/pythia_410m_hh_full_sft_trainer
ncgc/pythia_410m_sft_hh_random_45k
ncgc/pythia_410m_sft_hh_45k_lowest.bias
ncgc/pythia_410m_sft_hh_45k_highest.bias
```
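
To evaluate one of these adapters, load it on top of the base model with PEFT (a
minimal sketch, assuming the repositories above hold standard LoRA adapters as
saved by `SFT_Trainer_Lora.py`):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")

# Attach the LoRA adapter weights from the Hugging Face Hub.
model = PeftModel.from_pretrained(base, "ncgc/pythia_410m_hh_full_sft_trainer")
model.eval()
```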