Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ def run_rome(cfg: DictConfig | argparse.Namespace) -> None:

# ROME success test
prompt = handler.tokenize_prompt(fact_tuple[0].format(fact_tuple[1]))
target_token_count = int(handler.tokenize_prompt(fact_tuple[2]).input_ids.shape[1])
outputs = handler.model.generate(
**prompt,
max_length=prompt.input_ids.shape[1] + len(handler.tokenize_prompt(f" {fact_tuple[2]}")[0]) - 1,
max_length=prompt.input_ids.shape[1] + target_token_count,
)
print(handler.tokenizer.batch_decode(outputs))

Expand Down
2 changes: 1 addition & 1 deletion src/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,6 @@ def tokenize_prompt(self, prompt_text: str | List[str], apply_template: bool = F
inputs = self.tokenizer(prompt_text, return_tensors="pt")


inputs = self.device_manager.safe_to_device(inputs)
inputs = self.device_manager.safe_to_device(inputs, device=self.device)

return inputs
26 changes: 23 additions & 3 deletions src/handlers/rome.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,18 @@ def __init__(self, cfg: DictConfig) -> None:
self.dtype = self.model.dtype
self.num_of_layers = self.model.config.num_hidden_layers

# Multi-GPU: detect if model is distributed via device_map
self.is_multi_gpu = hasattr(self.model, 'hf_device_map') and len(self.model.hf_device_map) > 1

# Initialize DeviceManager for CUDA-safe operations
device = getattr(cfg.model, "device", "cuda")
cuda_mode = getattr(cfg.model, "cuda_mode", CUDAMode.SOFT)
self.device_manager = DeviceManager(device, cuda_mode)
self.device = self.device_manager.get_device()
if self.is_multi_gpu:
# Use a concrete CUDA device (e.g., cuda:0) for model inputs.
self.device = next(self.model.parameters()).device
else:
self.device = self.device_manager.get_device()

self.batch_size = getattr(self.cfg.generation, "batch_size", 1) if hasattr(self.cfg, "generation") else 1

Expand Down Expand Up @@ -96,7 +103,7 @@ def __init__(self, cfg: DictConfig) -> None:
self.v = None
# Use device_manager for safe device placement
self.delta = torch.zeros((self.emb_shape), dtype=self.dtype)
self.delta = self.device_manager.safe_to_device(self.delta).requires_grad_(True)
self.delta = self.device_manager.safe_to_device(self.delta, device=self.device).requires_grad_(True)

self.second_moment_path = getattr(cfg.model, "second_moment_path", None)

Expand Down Expand Up @@ -197,7 +204,7 @@ def remove_hooks(self) -> None:
self._emb_accumulator = []

self.delta = torch.zeros((self.emb_shape))
self.delta = self.device_manager.safe_to_device(self.delta).requires_grad_(True)
self.delta = self.device_manager.safe_to_device(self.delta, device=self.device).requires_grad_(True)
for handle in self._hooks:
handle.remove()

Expand All @@ -213,6 +220,19 @@ def _get_module(self, module_name: str) -> torch.nn.Module:

raise KeyError(f"{module_name} not found")

def get_module_device(self, module_name: str | None = None) -> torch.device:
    """Return the device a specific module's parameters live on.

    Useful for multi-GPU setups (e.g. ``device_map``-sharded models) where
    individual layers may live on different devices.

    Args:
        module_name: Dotted name of the module to inspect. When ``None``,
            defaults to the currently targeted ROME layer, i.e.
            ``self._layer_name_template`` formatted with ``self._layer``.

    Returns:
        The device of the module's first parameter, or ``self.device``
        (wrapped in ``torch.device``) if the module has no parameters.

    Raises:
        KeyError: Propagated from ``self._get_module`` when ``module_name``
            does not resolve to a module of the model.
    """
    if module_name is None:
        module_name = self._layer_name_template.format(self._layer)
    module = self._get_module(module_name)
    try:
        return next(module.parameters()).device
    except StopIteration:
        # Parameter-less modules (e.g. activation functions) have no
        # intrinsic device; fall back to the handler-wide device.
        return torch.device(self.device)

def register_casual_hooks(self) -> None:
"""
"""
Expand Down
132 changes: 99 additions & 33 deletions src/rome/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,23 @@
:author: Jakub Res <iresj@fit.vut.cz>
"""

from __future__ import annotations

from pathlib import Path
import copy
import re
import random
import json
import torch
from typing import Tuple, List
from typing import Tuple, List, TYPE_CHECKING
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from enum import Enum
import numpy as np

if TYPE_CHECKING:
from src.handlers.rome import ModelHandler


import logging
LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -380,22 +385,27 @@ def gather_k(
prompts = handler.tokenize_prompt(templates)

prompt_count = int(prompts.input_ids.shape[0])
batch_idx = torch.arange(prompt_count, device=prompts.input_ids.device)
index = (prompts.attention_mask[batch_idx].sum(dim=1) - 1).long()
token_index = (prompts.attention_mask.detach().to("cpu").sum(dim=1) - 1).long()

# TODO: Add support for dynamic batch size
k = None
def k_hook(_, input):
nonlocal k
# Pair each prompt with its own last non-padding token index.
k = input[0][batch_idx, index, :].mean(dim=0)
local_batch_idx = torch.arange(prompt_count, device=input[0].device)
local_index = token_index.to(input[0].device)
k = input[0][local_batch_idx, local_index, :].mean(dim=0)
return input

handler.set_k_hook(k_hook)
handler.model(**prompts)
handler.remove_hooks()

return handler.device_manager.safe_to_device(k)
if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu:
target_device = handler.get_module_device(handler._layer_name_template.format(handler._layer))
else:
target_device = handler.device
return handler.device_manager.safe_to_device(k, device=target_device)


# https://medium.com/biased-algorithms/all-pairs-cosine-similarity-in-pytorch-064f9875d531
Expand Down Expand Up @@ -446,7 +456,8 @@ def get_subject_position(handler, prompt, subject):

def get_subject_index(handler, prompts, fact_tuple, subject_understanding_template) -> torch.Tensor | None:
new_target_ids = _strip_bos(handler, handler.tokenize_prompt(fact_tuple[2])["input_ids"][0])
last_subject_index = (prompts.attention_mask[torch.arange(prompts.input_ids.shape[0])].sum(dim=1))
batch_idx = torch.arange(prompts.input_ids.shape[0], device=prompts.attention_mask.device)
last_subject_index = prompts.attention_mask[batch_idx].sum(dim=1)

fact_prompt = handler.tokenize_prompt(fact_tuple[0].format(fact_tuple[1]))
u_fact_prompt = handler.tokenize_prompt(subject_understanding_template.format(fact_tuple[1]))
Expand All @@ -469,7 +480,7 @@ def get_subject_index(handler, prompts, fact_tuple, subject_understanding_templa
u_sub_reverse_pos = len(u_fact_prompt["input_ids"][0]) - pos
last_subject_index[-1] -= u_sub_reverse_pos

return last_subject_index
return last_subject_index.long().cpu()

def optimize_v(
handler,
Expand Down Expand Up @@ -506,11 +517,18 @@ def optimize_v(
if last_subject_index is None:
LOGGER.error("Subject index computation failed during v computation.")
return None
last_subject_index_list = [int(x) for x in last_subject_index.tolist()]

layer_name = handler._layer_name_template.format(handler._layer)
if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu:
layer_device = handler.get_module_device(layer_name)
else:
layer_device = handler.device

# The optimizer setup
# Create delta on CPU first, then move through device_manager for tracking
delta = torch.zeros((handler.emb_shape), requires_grad=False, dtype=handler.dtype)
delta = handler.device_manager.safe_to_device(delta).requires_grad_(True)
delta = handler.device_manager.safe_to_device(delta, device=layer_device).requires_grad_(True)

opt = torch.optim.Adam([delta], lr=handler.lr)

Expand All @@ -519,25 +537,24 @@ def delta_hook(module, _, output):
if module == handler._get_module(handler._layer_name_template.format(handler._layer)):
new_output = output.clone()
if v_init is None:
v_init = output[0, last_subject_index[0]].detach().clone()
for i, idx in enumerate(last_subject_index):
new_output[i, idx, :] = new_output[i, idx, :] + delta.to(output.dtype)
v_init = output[0, last_subject_index_list[0]].detach().clone()
for i, idx in enumerate(last_subject_index_list):
new_output[i, idx, :] = new_output[i, idx, :] + delta.to(device=output.device, dtype=output.dtype)
return new_output


# Create index for all the prompts and targets
target_len = int(new_target_ids.size(0))
prompt_device = prompts.input_ids.device
main_prompt_idx = torch.arange(N_prompts, device=prompt_device)
index_positions = (
prompts.attention_mask[:N_prompts].sum(dim=1).unsqueeze(1)
main_prompt_idx_cpu = torch.arange(N_prompts, dtype=torch.long)
index_positions_cpu = (
prompts.attention_mask[:N_prompts].detach().to("cpu").sum(dim=1).unsqueeze(1)
- target_len
+ torch.arange(target_len, device=prompt_device).unsqueeze(0)
+ torch.arange(target_len, dtype=torch.long).unsqueeze(0)
).long()

index_ids = new_target_ids.unsqueeze(0).repeat(N_prompts, 1)
dkl_prompt_idx = torch.arange(N_prompts, prompts.input_ids.shape[0], device=prompt_device)
dkl_index = (prompts.attention_mask[dkl_prompt_idx].sum(dim=1) - 1).long()
index_ids_cpu = new_target_ids.detach().to("cpu").long().unsqueeze(0).repeat(N_prompts, 1)
dkl_prompt_idx_cpu = torch.arange(N_prompts, prompts.input_ids.shape[0], dtype=torch.long)
dkl_index_cpu = (prompts.attention_mask.detach().to("cpu")[dkl_prompt_idx_cpu].sum(dim=1) - 1).long()

for i in range(N_optim_steps):
opt.zero_grad()
Expand All @@ -547,6 +564,13 @@ def delta_hook(module, _, output):
outputs = handler.model(**prompts)
handler.remove_hooks()

logits_device = outputs.logits.device
main_prompt_idx = main_prompt_idx_cpu.to(logits_device)
index_positions = index_positions_cpu.to(logits_device)
index_ids = index_ids_cpu.to(logits_device)
dkl_prompt_idx = dkl_prompt_idx_cpu.to(logits_device)
dkl_index = dkl_index_cpu.to(logits_device)

all_log_probs = torch.log_softmax(outputs.logits, dim=2)
log_probs = all_log_probs[
main_prompt_idx.unsqueeze(1),
Expand Down Expand Up @@ -582,16 +606,25 @@ def delta_hook(module, _, output):

return delta

def insert_kv(handler, k: torch.Tensor, delta: torch.Tensor) -> None:
old_W = handler._get_module(handler._layer_name_template.format(handler._layer)).weight.clone() # extract from the model
def insert_kv(handler: ModelHandler, k: torch.Tensor, delta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
layer_name = handler._layer_name_template.format(handler._layer)
# For multi-GPU, use the device where this layer actually lives
if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu:
layer_device = handler.get_module_device(layer_name)
else:
layer_device = handler.device

old_W = handler._get_module(layer_name).weight.clone()

# Fix the transposed models
old_W_transposed = False
if old_W.shape[0] != k.shape[0]:
old_W = torch.transpose(old_W,0,1)
old_W_transposed = True

inv_cov = get_second_moment(handler).to(handler.dtype).to(handler.device)
inv_cov = get_second_moment(handler).to(handler.dtype).to(layer_device)
k = k.to(layer_device)
delta = delta.to(layer_device)
left = inv_cov @ k.unsqueeze(1)
left = left.squeeze()
left = left / left.norm()
Expand Down Expand Up @@ -627,7 +660,7 @@ def second_moment_wikipedia(handler, N_rounds, N_k):
Returns C^-1 (needed for ROME weight update formula)

"""
from src.utils import load_dataset
from src.utils import load_dataset, estimate_covariance_batch_size

layer_name = handler._layer_name_template.format(handler._layer)
module = handler._get_module(layer_name)
Expand All @@ -637,26 +670,54 @@ def second_moment_wikipedia(handler, N_rounds, N_k):
max_length = getattr(handler.model.config, 'n_positions',
getattr(handler.model.config, 'max_position_embeddings', 1024))

# For multi-GPU models, determine the device of the target module
if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu:
module_device = handler.get_module_device(layer_name)
else:
module_device = handler.device

# Accumulate second moment directly on GPU instead of storing all k vectors
C = torch.zeros(hidden_dim, hidden_dim, dtype=torch.float32, device=handler.device)
total_tokens = 0 # Use list to allow modification in hook
C = torch.zeros(hidden_dim, hidden_dim, dtype=torch.float32, device=module_device)
total_tokens = 0

def hook(_, inp, out):
    """Forward hook: fold the module's input activations into the running
    uncentered second-moment accumulator ``C`` (adds k^T @ k) and count the
    tokens seen in ``total_tokens``. Returns ``out`` unchanged.
    """
    nonlocal C, total_tokens
    # nn.Module forward hooks usually receive inputs as a tuple; unwrap the
    # first positional input if so. Detach + float32 for stable accumulation.
    k = inp[0].detach().float() if isinstance(inp, tuple) else inp.detach().float()
    # Flatten [batch, seq, hidden] to a matrix of per-token vectors.
    if len(k.shape) == 3:
        k = k.view(-1, k.shape[-1]) # [batch*seq, hidden]
    # Ensure k is on the same device as C
    k = k.to(C.device)
    # NOTE(review): padding-token positions are included in the statistics
    # (the attention mask is not consulted here) — confirm this is intended.
    total_tokens += k.shape[0]
    C.add_(k.T @ k)
    return out

handle = module.register_forward_hook(hook)

n_samples = N_rounds * N_k if N_rounds and N_k else 5000
batch_size = 8 # Process multiple texts at once

# Dynamic batch size based on available VRAM
dtype_bytes = 2 if handler.dtype in (torch.float16, torch.bfloat16) else 4
batch_size = estimate_covariance_batch_size(
hidden_dim=hidden_dim,
max_length=max_length,
dtype_bytes=dtype_bytes,
device=module_device,
)

LOGGER.info(f"Starting covariance computation: {n_samples} samples, batch_size={batch_size}, max_length={max_length}")
ds = load_dataset(handler.cfg, sm=True)

# For multi-GPU models, place token inputs on the embedding module device.
if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu:
try:
input_module_name = handler._corrupt_layer_name_template
if "{}" in input_module_name:
input_module_name = input_module_name.format(0)
input_device = handler.get_module_device(input_module_name)
except Exception:
input_device = next(handler.model.parameters()).device
else:
input_device = handler.device

processed = 0
batch_texts = []
Expand All @@ -682,15 +743,20 @@ def hook(_, inp, out):
max_length=max_length,
padding=True
)
handler.model(tokens.input_ids.to(handler.device),
attention_mask=tokens.attention_mask.to(handler.device))
handler.model(tokens.input_ids.to(input_device),
attention_mask=tokens.attention_mask.to(input_device))
processed += len(batch_texts)
except torch.cuda.OutOfMemoryError:
LOGGER.warning("OOM during covariance computation, halving batch size")
if torch.cuda.is_available():
torch.cuda.empty_cache()
batch_size = max(1, batch_size // 2)
except Exception as e:
LOGGER.warning(e)
pass # Skip failed batches
batch_texts = []
# Clear GPU cache periodically
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Process remaining texts
if batch_texts and processed < n_samples:
Expand All @@ -702,10 +768,10 @@ def hook(_, inp, out):
max_length=max_length,
padding=True
)
handler.model(tokens.input_ids.to(handler.device),
attention_mask=tokens.attention_mask.to(handler.device))
handler.model(tokens.input_ids.to(input_device),
attention_mask=tokens.attention_mask.to(input_device))
processed += len(batch_texts)
except:
except Exception:
pass

handle.remove()
Expand Down
4 changes: 2 additions & 2 deletions src/rome/rome.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def batch_evaluation(cfg: DictConfig) -> None:
def single_intervention(handler: ModelHandler, fact_tuple: Tuple[str,str,str,str]) -> None:
k = gather_k(handler, fact_tuple=fact_tuple, N=getattr(handler.cfg.generation, 'k_N', 50))
delta = optimize_v(handler, fact_tuple, N_prompts=getattr(handler.cfg.generation, 'v_N', 50), N_optim_steps=handler.epochs)
new_W, old_W = insert_kv(handler, k, delta)
new_W, old_W, _ = insert_kv(handler, k, delta)

if handler.save_new_weights:
out_path = Path(handler.new_weights_dir) / f"{handler.cfg.model.name.replace('/', '-')}_{handler._layer}.pt"
Expand Down Expand Up @@ -189,7 +189,7 @@ def main(cfg: DictConfig) -> None:
delta = optimize_v(handler, k, fact_tuple, N_prompts=50, N_optim_steps=handler.epochs, epsilon=0.005)
LOGGER.info(f"delta computed, shape: {delta.shape}")

new_W, old_W = insert_kv(handler, k, delta)
new_W, old_W, _ = insert_kv(handler, k, delta)
LOGGER.info(f"New weights computed")

if handler.save_new_weights:
Expand Down
Loading