Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions docs/generate_sae_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pathlib import Path

import pandas as pd
import yaml
from tqdm import tqdm

from sae_lens import SAEConfig
Expand All @@ -11,6 +10,7 @@
get_sae_config,
handle_config_defaulting,
)
from sae_lens.toolkit.pretrained_saes_directory import load_pretrained_saes_yaml

INCLUDED_CFG = [
"id",
Expand All @@ -32,10 +32,7 @@ def on_pre_build(config):


def generate_sae_table():
# Read the YAML file
yaml_path = Path("sae_lens/pretrained_saes.yaml")
with open(yaml_path, "r") as file:
data = yaml.safe_load(file)
data = load_pretrained_saes_yaml()

# Start the Markdown content
markdown_content = "# Pretrained SAEs\n\n"
Expand Down
73 changes: 36 additions & 37 deletions sae_lens/toolkit/pretrained_saes_directory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from functools import cache
from importlib import resources
from typing import Optional
from typing import Any, Optional

import yaml

Expand All @@ -19,39 +19,40 @@ class PretrainedSAELookup:
config_overrides: dict[str, str] | dict[str, dict[str, str | bool | int]] | None


@cache
def load_pretrained_saes_yaml() -> dict[str, Any]:
with resources.open_text("sae_lens", "pretrained_saes.yaml") as file:
return yaml.safe_load(file)


@cache
def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
package = "sae_lens"
# Access the file within the package using importlib.resources
directory: dict[str, PretrainedSAELookup] = {}
with resources.open_text(package, "pretrained_saes.yaml") as file:
# Load the YAML file content
data = yaml.safe_load(file)
for release, value in data.items():
saes_map: dict[str, str] = {}
var_explained_map: dict[str, float] = {}
l0_map: dict[str, float] = {}
neuronpedia_id_map: dict[str, str] = {}

assert "saes" in value, f"Missing 'saes' key in {release}"
for hook_info in value["saes"]:
saes_map[hook_info["id"]] = hook_info["path"]
var_explained_map[hook_info["id"]] = hook_info.get(
"variance_explained", 1.00
)
l0_map[hook_info["id"]] = hook_info.get("l0", 0.00)
neuronpedia_id_map[hook_info["id"]] = hook_info.get("neuronpedia")
directory[release] = PretrainedSAELookup(
release=release,
repo_id=value["repo_id"],
model=value["model"],
conversion_func=value.get("conversion_func"),
saes_map=saes_map,
expected_var_explained=var_explained_map,
expected_l0=l0_map,
neuronpedia_id=neuronpedia_id_map,
config_overrides=value.get("config_overrides"),
data = load_pretrained_saes_yaml()
for release, value in data.items():
saes_map: dict[str, str] = {}
var_explained_map: dict[str, float] = {}
l0_map: dict[str, float] = {}
neuronpedia_id_map: dict[str, str] = {}
assert "saes" in value, f"Missing 'saes' key in {release}"
for hook_info in value["saes"]:
saes_map[hook_info["id"]] = hook_info["path"]
var_explained_map[hook_info["id"]] = hook_info.get(
"variance_explained", 1.00
)
l0_map[hook_info["id"]] = hook_info.get("l0", 0.00)
neuronpedia_id_map[hook_info["id"]] = hook_info.get("neuronpedia")
directory[release] = PretrainedSAELookup(
release=release,
repo_id=value["repo_id"],
model=value["model"],
conversion_func=value.get("conversion_func"),
saes_map=saes_map,
expected_var_explained=var_explained_map,
expected_l0=l0_map,
neuronpedia_id=neuronpedia_id_map,
config_overrides=value.get("config_overrides"),
)
return directory


Expand All @@ -66,13 +67,11 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> Optional[float]:
Returns:
Optional[float]: The norm_scaling_factor if it exists, None otherwise.
"""
package = "sae_lens"
with resources.open_text(package, "pretrained_saes.yaml") as file:
data = yaml.safe_load(file)
if release in data:
for sae_info in data[release]["saes"]:
if sae_info["id"] == sae_id:
return sae_info.get("norm_scaling_factor")
data = load_pretrained_saes_yaml()
if release in data:
for sae_info in data[release]["saes"]:
if sae_info["id"] == sae_id:
return sae_info.get("norm_scaling_factor")
return None


Expand Down